1use anyhow::{anyhow, bail, Result};
2use home::home_dir;
3use std::{
4 collections::{HashMap, HashSet},
5 fmt,
6 fs::{File, OpenOptions},
7 io::{Read, Seek, SeekFrom, Write},
8};
9
10use crate::{line_lexer::EntryLineLexer, line_parser::EntryLineParser};
11
/// Identifies a profile section in the AWS `config` / `credentials` files.
///
/// The literal profile name `default` is represented by [`ProfileName::Default`];
/// every other name is carried verbatim in [`ProfileName::Named`].
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ProfileName {
    Default,
    Named(String),
}

impl<S> From<S> for ProfileName
where
    S: AsRef<str> + Into<String>,
{
    /// Maps `"default"` to [`ProfileName::Default`] and any other string to
    /// [`ProfileName::Named`] holding that string.
    fn from(value: S) -> Self {
        if value.as_ref() != "default" {
            return Self::Named(value.into());
        }
        Self::Default
    }
}

impl fmt::Display for ProfileName {
    /// Writes the bare profile name (`default` for the default profile).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            Self::Default => "default",
            Self::Named(name) => name.as_str(),
        };
        f.write_str(text)
    }
}
39
40#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
41pub struct AwsProfile {
42 pub config_comments: Vec<String>,
44
45 pub credentials_comments: Vec<String>,
47
48 pub is_production: bool,
50
51 pub is_locked: bool,
53
54 pub name: ProfileName,
56
57 pub region: Option<String>,
59
60 pub output: Option<String>,
62
63 pub aws_access_key_id: String,
65
66 pub aws_secret_access_key: String,
68}
69
70#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
71struct AwsConfig {
72 comments: Vec<String>,
73 is_production: bool,
74 is_locked: bool,
75 name: ProfileName,
76 region: Option<String>,
77 output: Option<String>,
78}
79
80#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
81struct AwsCredential {
82 comments: Vec<String>,
83 is_production: bool,
84 is_locked: bool,
85 name: ProfileName,
86 aws_access_key_id: String,
87 aws_secret_access_key: String,
88}
89
/// Open read/write handles to the AWS `config` and `credentials` files.
#[derive(Debug)]
pub struct AwsFile {
    /// Handle to the `config` file.
    config: File,
    /// Handle to the `credentials` file.
    credentials: File,
}
95
96impl AwsFile {
97 pub fn open() -> Result<AwsFile> {
98 let home_dir = home_dir().expect("failed to locate home directory");
99
100 let aws_home = home_dir.join(".aws");
101 let config = OpenOptions::new()
102 .read(true)
103 .write(true)
104 .open(aws_home.join("config"))?;
105 let credentials = OpenOptions::new()
106 .read(true)
107 .write(true)
108 .open(aws_home.join("credentials"))?;
109
110 Ok(AwsFile {
111 config,
112 credentials,
113 })
114 }
115
116 pub fn flush(&mut self) -> Result<()> {
117 self.config.flush()?;
118 self.credentials.flush()?;
119
120 Ok(())
121 }
122
123 pub fn parse(&mut self) -> Result<Vec<AwsProfile>> {
124 let config = self.parse_config()?;
125 let config_names: Vec<_> = config.iter().map(|conf| conf.name.clone()).collect();
126 let mut config: HashMap<_, _> = config
127 .into_iter()
128 .map(|conf| (conf.name.clone(), conf))
129 .collect();
130
131 let credentials = self.parse_credentials()?;
132 let credentials_names: Vec<_> = credentials.iter().map(|cred| cred.name.clone()).collect();
133 let mut credentials: HashMap<_, _> = credentials
134 .into_iter()
135 .map(|cred| (cred.name.clone(), cred))
136 .collect();
137
138 let mut names = vec![];
139 let mut inserted = HashSet::new();
140 for name in config_names.iter().chain(&credentials_names) {
141 if inserted.insert(name) {
142 names.push(name);
143 }
144 }
145
146 names
147 .into_iter()
148 .map(|name| {
149 let conf = config
150 .remove(name)
151 .ok_or_else(|| anyhow!("config '{name}' not found",))?;
152 let cred = credentials
153 .remove(name)
154 .ok_or_else(|| anyhow!("credentials '{name}' not found",))?;
155
156 Ok(AwsProfile {
157 config_comments: conf.comments,
158 credentials_comments: cred.comments,
159 is_production: conf.is_production || cred.is_production,
160 is_locked: conf.is_locked || cred.is_locked,
161 name: name.clone(),
162 region: conf.region,
163 output: conf.output,
164 aws_access_key_id: cred.aws_access_key_id,
165 aws_secret_access_key: cred.aws_secret_access_key,
166 })
167 })
168 .collect()
169 }
170
171 fn parse_config(&mut self) -> Result<Vec<AwsConfig>> {
172 let mut buf = String::new();
173 self.config.seek(SeekFrom::Start(0))?;
174 self.config.read_to_string(&mut buf)?;
175 let lexer = &mut EntryLineLexer::new(&buf);
176 let lines = lexer.tokenize()?;
177 let entries = EntryLineParser::new(lines).parse()?;
178
179 entries
180 .into_iter()
181 .map(|entry| {
182 let name = if entry.header == "default" {
183 ProfileName::Default
184 } else {
185 match *entry.header.splitn(2, ' ').collect::<Vec<_>>() {
186 [lit_profile, name] if lit_profile == "profile" => name.into(),
187 _ => bail!("unexpected header in your config: {:?}", entry.header),
188 }
189 };
190 let region = entry.values.get("region").cloned();
191 let output = entry.values.get("output").cloned();
192 Ok(AwsConfig {
193 comments: entry.comments,
194 is_production: entry.is_production,
195 is_locked: entry.is_locked,
196 name,
197 region,
198 output,
199 })
200 })
201 .collect()
202 }
203
204 fn parse_credentials(&mut self) -> Result<Vec<AwsCredential>> {
205 let mut buf = String::new();
206 self.config.seek(SeekFrom::Start(0))?;
207 self.credentials.read_to_string(&mut buf)?;
208 let lexer = &mut EntryLineLexer::new(&buf);
209 let lines = lexer.tokenize()?;
210 let entries = EntryLineParser::new(lines).parse()?;
211
212 entries
213 .into_iter()
214 .map(|entry| {
215 let name = entry.header.into();
216 let aws_access_key_id = entry
217 .values
218 .get("aws_access_key_id")
219 .ok_or_else(|| {
220 anyhow!("failed to find 'aws_access_key_id' in your credentials")
221 })?
222 .to_string();
223 let aws_secret_access_key = entry
224 .values
225 .get("aws_secret_access_key")
226 .ok_or_else(|| {
227 anyhow!("failed to find 'aws_secret_access_key' in your credentials")
228 })?
229 .to_string();
230 Ok(AwsCredential {
231 comments: entry.comments,
232 is_production: entry.is_production,
233 is_locked: entry.is_locked,
234 name,
235 aws_access_key_id,
236 aws_secret_access_key,
237 })
238 })
239 .collect()
240 }
241
242 pub fn write(&mut self, profiles: &[AwsProfile]) -> Result<()> {
243 let config: Vec<_> = profiles
244 .iter()
245 .map(|profile| AwsConfig {
246 comments: profile.config_comments.clone(),
247 is_production: profile.is_production,
248 is_locked: profile.is_locked,
249 name: profile.name.clone(),
250 region: profile.region.clone(),
251 output: profile.output.clone(),
252 })
253 .collect();
254 let credentials: Vec<_> = profiles
255 .iter()
256 .map(|profile| AwsCredential {
257 comments: profile.credentials_comments.clone(),
258 is_production: profile.is_production,
259 is_locked: profile.is_locked,
260 name: profile.name.clone(),
261 aws_access_key_id: profile.aws_access_key_id.clone(),
262 aws_secret_access_key: profile.aws_secret_access_key.clone(),
263 })
264 .collect();
265 self.write_config(&config)?;
266 self.write_credentials(&credentials)?;
267
268 Ok(())
269 }
270
271 fn write_config(&mut self, config: &[AwsConfig]) -> Result<()> {
272 self.config.seek(SeekFrom::Start(0))?;
273 self.config.set_len(0)?;
274
275 let mut first = true;
276 for conf in config {
277 if !first {
278 writeln!(self.config)?;
279 }
280 first = false;
281
282 for comment in &conf.comments {
283 writeln!(self.config, "# {}", comment)?;
284 }
285
286 if conf.is_production {
287 writeln!(self.config, "# production")?;
288 }
289
290 let locked_prefix = if conf.is_locked { "# " } else { "" };
291
292 match &conf.name {
293 ProfileName::Named(name) => {
294 writeln!(self.config, "{}[profile {}]", locked_prefix, name)?
295 }
296 ProfileName::Default => writeln!(self.config, "{}[default]", locked_prefix)?,
297 }
298
299 if let Some(region) = &conf.region {
300 writeln!(self.config, "{}region = {}", locked_prefix, region)?;
301 }
302
303 if let Some(output) = &conf.output {
304 writeln!(self.config, "{}output = {}", locked_prefix, output)?;
305 }
306 }
307
308 Ok(())
309 }
310
311 fn write_credentials(&mut self, credentials: &[AwsCredential]) -> Result<()> {
312 self.credentials.seek(SeekFrom::Start(0))?;
313 self.credentials.set_len(0)?;
314
315 let mut first = true;
316 for cred in credentials {
317 if !first {
318 writeln!(self.credentials)?;
319 }
320 first = false;
321
322 for comment in &cred.comments {
323 writeln!(self.credentials, "# {}", comment)?;
324 }
325
326 if cred.is_production {
327 writeln!(self.credentials, "# production")?;
328 }
329
330 let locked_prefix = if cred.is_locked { "# " } else { "" };
331
332 writeln!(self.credentials, "{}[{}]", locked_prefix, cred.name)?;
333 writeln!(
334 self.credentials,
335 "{}aws_access_key_id = {}",
336 locked_prefix, cred.aws_access_key_id
337 )?;
338 writeln!(
339 self.credentials,
340 "{}aws_secret_access_key = {}",
341 locked_prefix, cred.aws_secret_access_key
342 )?;
343 }
344
345 Ok(())
346 }
347}