node_discover/
args.rs

1use std::{
2    collections::{hash_map::IntoIter, HashMap},
3    convert::TryFrom,
4    fmt::Display,
5};
6
7use crate::errors::DiscoverError;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Deserialize, Serialize)]
11pub enum SupportedProvider {
12    #[serde(rename = "aws")]
13    AWS,
14    #[serde(rename = "digitalocean")]
15    DigitalOcean,
16}
17
18impl Display for SupportedProvider {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        let provider = serde_json::to_string(self).expect("To serialize supported provider name");
21        write!(f, "{}", provider)
22    }
23}
24
25/// A utility type for parsing and working with the CLI arguments
26#[derive(Debug, Clone)]
27pub struct ParsedArgs {
28    inner: HashMap<String, String>,
29    provider: SupportedProvider,
30}
31
32impl ParsedArgs {
33    pub fn get(&self, key: &str) -> Option<&String> {
34        self.inner.get(key)
35    }
36
37    pub fn provider(&self) -> &SupportedProvider {
38        &self.provider
39    }
40}
41
42impl IntoIterator for ParsedArgs {
43    type Item = (String, String);
44
45    type IntoIter = IntoIter<String, String>;
46
47    fn into_iter(self) -> Self::IntoIter {
48        self.inner.into_iter()
49    }
50}
51
52impl TryFrom<Vec<String>> for ParsedArgs {
53    type Error = DiscoverError;
54
55    fn try_from(value: Vec<String>) -> Result<Self, Self::Error> {
56        let mut args = HashMap::with_capacity(value.len());
57        for arg_str in value {
58            let arg = arg_str.splitn(2, '=').collect::<Vec<_>>();
59
60            if arg.len() != 2 || arg[1].is_empty() {
61                return Err(DiscoverError::MalformedArgument(
62                    arg[0].to_string(),
63                    "Expected an argument on the format: key=value".to_string(),
64                ));
65            }
66
67            // Fail on duplicate arg
68            if args
69                .insert(arg[0].to_string(), arg[1].to_string())
70                .is_some()
71            {
72                return Err(DiscoverError::DuplicateArgument(arg[0].to_string()));
73            }
74        }
75
76        let provider = match args.get("provider") {
77            // provider must always be provided
78            None => return Err(DiscoverError::MissingArgument("provider".into())),
79            Some(p) => match &p.to_lowercase()[..] {
80                "aws" => SupportedProvider::AWS,
81                "digitalocean" => SupportedProvider::DigitalOcean,
82                _ => return Err(DiscoverError::UnsupportedProvider(p.to_string())),
83            },
84        };
85
86        Ok(Self {
87            inner: args,
88            provider,
89        })
90    }
91}
92
93impl TryFrom<String> for ParsedArgs {
94    type Error = DiscoverError;
95
96    fn try_from(value: String) -> Result<Self, Self::Error> {
97        let args = value
98            .trim()
99            .split(' ')
100            .map(String::from)
101            .collect::<Vec<_>>();
102        ParsedArgs::try_from(args)
103    }
104}
105
#[cfg(test)]
mod test {
    use super::*;

    /// Builds the error `ParsedArgs` reports for a token that is not `key=value`.
    fn malformed(key: &str) -> DiscoverError {
        DiscoverError::MalformedArgument(
            key.to_string(),
            "Expected an argument on the format: key=value".to_string(),
        )
    }

    #[test]
    fn fail_when_provider_is_not_provided() {
        let (tag_key, tag_value, addr_type) = ("Name", "fsajfopja", "private_v4");
        let input = format!(
            "region=eu-west-1 tag_key={} tag_value={} addr_type={}",
            tag_key, tag_value, addr_type
        );

        let err = ParsedArgs::try_from(input).unwrap_err();
        assert_eq!(err, DiscoverError::MissingArgument("provider".to_string()));
    }

    #[test]
    fn fail_on_duplicate_argument() {
        ["provider=aws provider=do", "provider=aws provider=aws"]
            .iter()
            .for_each(|&input| {
                let err = ParsedArgs::try_from(input.to_string()).unwrap_err();
                assert_eq!(err, DiscoverError::DuplicateArgument("provider".to_string()));
            });
    }

    #[test]
    fn fail_on_garbage_input() {
        for &input in &["", "!!", "?"] {
            let err = ParsedArgs::try_from(input.to_string()).unwrap_err();
            assert_eq!(err, malformed(input));
        }
    }

    #[test]
    fn fail_on_malformed_args() {
        for &malformed_arg in &["=", "x:y", "zzzz", "t?x", "help=", "key"] {
            let input = format!("provider=aws region=eu-west-1 {}", malformed_arg);
            let err = ParsedArgs::try_from(input).unwrap_err();

            // A trailing `=` means the key was present but the value empty;
            // the error then names only the text before the `=`.
            let expected_key = malformed_arg.strip_suffix('=').unwrap_or(malformed_arg);
            assert_eq!(err, malformed(expected_key));
        }
    }
}