aws_unlock/
aws_profile.rs

1use anyhow::{anyhow, bail, Result};
2use home::home_dir;
3use std::{
4    collections::{HashMap, HashSet},
5    fmt,
6    fs::{File, OpenOptions},
7    io::{Read, Seek, SeekFrom, Write},
8};
9
10use crate::{line_lexer::EntryLineLexer, line_parser::EntryLineParser};
11
/// Identifies an AWS profile: either the special `default` profile or a named one.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum ProfileName {
    /// The profile whose section header is `[default]`.
    Default,
    /// Any other profile, carrying its name.
    Named(String),
}

impl<S> From<S> for ProfileName
where
    S: AsRef<str> + Into<String>,
{
    /// Maps the literal string `"default"` to [`ProfileName::Default`]; every other
    /// string becomes [`ProfileName::Named`] holding that string.
    fn from(value: S) -> Self {
        match value.as_ref() {
            "default" => Self::Default,
            _ => Self::Named(value.into()),
        }
    }
}

impl fmt::Display for ProfileName {
    /// Renders the bare profile name: `default` for the default profile, otherwise the
    /// stored name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            Self::Default => "default",
            Self::Named(name) => name.as_str(),
        };
        f.write_str(text)
    }
}
39
40#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
41pub struct AwsProfile {
42    /// Comment lines in config file.
43    pub config_comments: Vec<String>,
44
45    /// Comment lines in credentials file.
46    pub credentials_comments: Vec<String>,
47
48    /// Whether this profile is for production environment or not.
49    pub is_production: bool,
50
51    /// Whether this profile is currently locked or not.
52    pub is_locked: bool,
53
54    /// The profile name. None if it is default profile.
55    pub name: ProfileName,
56
57    /// `region` in ~/.aws/config.
58    pub region: Option<String>,
59
60    /// `output` in ~/.aws/config.
61    pub output: Option<String>,
62
63    /// `aws_access_key_id` in ~/.aws/credentials.
64    pub aws_access_key_id: String,
65
66    /// `aws_secret_access_key` in ~/.aws/credentials.
67    pub aws_secret_access_key: String,
68}
69
70#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
71struct AwsConfig {
72    comments: Vec<String>,
73    is_production: bool,
74    is_locked: bool,
75    name: ProfileName,
76    region: Option<String>,
77    output: Option<String>,
78}
79
80#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
81struct AwsCredential {
82    comments: Vec<String>,
83    is_production: bool,
84    is_locked: bool,
85    name: ProfileName,
86    aws_access_key_id: String,
87    aws_secret_access_key: String,
88}
89
/// Read/write handles to the user's `~/.aws/config` and `~/.aws/credentials` files.
#[derive(Debug)]
pub struct AwsFile {
    /// Handle to `~/.aws/config`.
    config: File,
    /// Handle to `~/.aws/credentials`.
    credentials: File,
}
95
96impl AwsFile {
97    pub fn open() -> Result<AwsFile> {
98        let home_dir = home_dir().expect("failed to locate home directory");
99
100        let aws_home = home_dir.join(".aws");
101        let config = OpenOptions::new()
102            .read(true)
103            .write(true)
104            .open(aws_home.join("config"))?;
105        let credentials = OpenOptions::new()
106            .read(true)
107            .write(true)
108            .open(aws_home.join("credentials"))?;
109
110        Ok(AwsFile {
111            config,
112            credentials,
113        })
114    }
115
116    pub fn flush(&mut self) -> Result<()> {
117        self.config.flush()?;
118        self.credentials.flush()?;
119
120        Ok(())
121    }
122
123    pub fn parse(&mut self) -> Result<Vec<AwsProfile>> {
124        let config = self.parse_config()?;
125        let config_names: Vec<_> = config.iter().map(|conf| conf.name.clone()).collect();
126        let mut config: HashMap<_, _> = config
127            .into_iter()
128            .map(|conf| (conf.name.clone(), conf))
129            .collect();
130
131        let credentials = self.parse_credentials()?;
132        let credentials_names: Vec<_> = credentials.iter().map(|cred| cred.name.clone()).collect();
133        let mut credentials: HashMap<_, _> = credentials
134            .into_iter()
135            .map(|cred| (cred.name.clone(), cred))
136            .collect();
137
138        let mut names = vec![];
139        let mut inserted = HashSet::new();
140        for name in config_names.iter().chain(&credentials_names) {
141            if inserted.insert(name) {
142                names.push(name);
143            }
144        }
145
146        names
147            .into_iter()
148            .map(|name| {
149                let conf = config
150                    .remove(name)
151                    .ok_or_else(|| anyhow!("config '{name}' not found",))?;
152                let cred = credentials
153                    .remove(name)
154                    .ok_or_else(|| anyhow!("credentials '{name}' not found",))?;
155
156                Ok(AwsProfile {
157                    config_comments: conf.comments,
158                    credentials_comments: cred.comments,
159                    is_production: conf.is_production || cred.is_production,
160                    is_locked: conf.is_locked || cred.is_locked,
161                    name: name.clone(),
162                    region: conf.region,
163                    output: conf.output,
164                    aws_access_key_id: cred.aws_access_key_id,
165                    aws_secret_access_key: cred.aws_secret_access_key,
166                })
167            })
168            .collect()
169    }
170
171    fn parse_config(&mut self) -> Result<Vec<AwsConfig>> {
172        let mut buf = String::new();
173        self.config.seek(SeekFrom::Start(0))?;
174        self.config.read_to_string(&mut buf)?;
175        let lexer = &mut EntryLineLexer::new(&buf);
176        let lines = lexer.tokenize()?;
177        let entries = EntryLineParser::new(lines).parse()?;
178
179        entries
180            .into_iter()
181            .map(|entry| {
182                let name = if entry.header == "default" {
183                    ProfileName::Default
184                } else {
185                    match *entry.header.splitn(2, ' ').collect::<Vec<_>>() {
186                        [lit_profile, name] if lit_profile == "profile" => name.into(),
187                        _ => bail!("unexpected header in your config: {:?}", entry.header),
188                    }
189                };
190                let region = entry.values.get("region").cloned();
191                let output = entry.values.get("output").cloned();
192                Ok(AwsConfig {
193                    comments: entry.comments,
194                    is_production: entry.is_production,
195                    is_locked: entry.is_locked,
196                    name,
197                    region,
198                    output,
199                })
200            })
201            .collect()
202    }
203
204    fn parse_credentials(&mut self) -> Result<Vec<AwsCredential>> {
205        let mut buf = String::new();
206        self.config.seek(SeekFrom::Start(0))?;
207        self.credentials.read_to_string(&mut buf)?;
208        let lexer = &mut EntryLineLexer::new(&buf);
209        let lines = lexer.tokenize()?;
210        let entries = EntryLineParser::new(lines).parse()?;
211
212        entries
213            .into_iter()
214            .map(|entry| {
215                let name = entry.header.into();
216                let aws_access_key_id = entry
217                    .values
218                    .get("aws_access_key_id")
219                    .ok_or_else(|| {
220                        anyhow!("failed to find 'aws_access_key_id' in your credentials")
221                    })?
222                    .to_string();
223                let aws_secret_access_key = entry
224                    .values
225                    .get("aws_secret_access_key")
226                    .ok_or_else(|| {
227                        anyhow!("failed to find 'aws_secret_access_key' in your credentials")
228                    })?
229                    .to_string();
230                Ok(AwsCredential {
231                    comments: entry.comments,
232                    is_production: entry.is_production,
233                    is_locked: entry.is_locked,
234                    name,
235                    aws_access_key_id,
236                    aws_secret_access_key,
237                })
238            })
239            .collect()
240    }
241
242    pub fn write(&mut self, profiles: &[AwsProfile]) -> Result<()> {
243        let config: Vec<_> = profiles
244            .iter()
245            .map(|profile| AwsConfig {
246                comments: profile.config_comments.clone(),
247                is_production: profile.is_production,
248                is_locked: profile.is_locked,
249                name: profile.name.clone(),
250                region: profile.region.clone(),
251                output: profile.output.clone(),
252            })
253            .collect();
254        let credentials: Vec<_> = profiles
255            .iter()
256            .map(|profile| AwsCredential {
257                comments: profile.credentials_comments.clone(),
258                is_production: profile.is_production,
259                is_locked: profile.is_locked,
260                name: profile.name.clone(),
261                aws_access_key_id: profile.aws_access_key_id.clone(),
262                aws_secret_access_key: profile.aws_secret_access_key.clone(),
263            })
264            .collect();
265        self.write_config(&config)?;
266        self.write_credentials(&credentials)?;
267
268        Ok(())
269    }
270
271    fn write_config(&mut self, config: &[AwsConfig]) -> Result<()> {
272        self.config.seek(SeekFrom::Start(0))?;
273        self.config.set_len(0)?;
274
275        let mut first = true;
276        for conf in config {
277            if !first {
278                writeln!(self.config)?;
279            }
280            first = false;
281
282            for comment in &conf.comments {
283                writeln!(self.config, "# {}", comment)?;
284            }
285
286            if conf.is_production {
287                writeln!(self.config, "# production")?;
288            }
289
290            let locked_prefix = if conf.is_locked { "# " } else { "" };
291
292            match &conf.name {
293                ProfileName::Named(name) => {
294                    writeln!(self.config, "{}[profile {}]", locked_prefix, name)?
295                }
296                ProfileName::Default => writeln!(self.config, "{}[default]", locked_prefix)?,
297            }
298
299            if let Some(region) = &conf.region {
300                writeln!(self.config, "{}region = {}", locked_prefix, region)?;
301            }
302
303            if let Some(output) = &conf.output {
304                writeln!(self.config, "{}output = {}", locked_prefix, output)?;
305            }
306        }
307
308        Ok(())
309    }
310
311    fn write_credentials(&mut self, credentials: &[AwsCredential]) -> Result<()> {
312        self.credentials.seek(SeekFrom::Start(0))?;
313        self.credentials.set_len(0)?;
314
315        let mut first = true;
316        for cred in credentials {
317            if !first {
318                writeln!(self.credentials)?;
319            }
320            first = false;
321
322            for comment in &cred.comments {
323                writeln!(self.credentials, "# {}", comment)?;
324            }
325
326            if cred.is_production {
327                writeln!(self.credentials, "# production")?;
328            }
329
330            let locked_prefix = if cred.is_locked { "# " } else { "" };
331
332            writeln!(self.credentials, "{}[{}]", locked_prefix, cred.name)?;
333            writeln!(
334                self.credentials,
335                "{}aws_access_key_id = {}",
336                locked_prefix, cred.aws_access_key_id
337            )?;
338            writeln!(
339                self.credentials,
340                "{}aws_secret_access_key = {}",
341                locked_prefix, cred.aws_secret_access_key
342            )?;
343        }
344
345        Ok(())
346    }
347}