1use std::collections::{HashMap, HashSet};
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use hmac::{Hmac, Mac};
5use sha2::{Digest, Sha256};
6
7use super::{Provider, ProviderError, ProviderHost};
8
/// AWS EC2 host provider.
///
/// Instances are listed from every region in `regions`, one region after
/// another. Credentials come from the named `profile` when it is non-empty,
/// otherwise from the provider token or the environment (see
/// `resolve_credentials`).
pub struct Aws {
    /// Region codes to query, e.g. `us-east-1`. Each must appear in
    /// `AWS_REGIONS`; an unknown code is rejected before any request is sent.
    pub regions: Vec<String>,
    /// Profile name in the shared credentials file (`~/.aws/credentials`).
    /// An empty string means the credentials file is not consulted.
    pub profile: String,
}
13
/// Supported EC2 regions as `(region code, location label)` pairs.
///
/// Configured regions are validated against the codes in this table. The
/// order is significant: `AWS_REGION_GROUPS` refers to contiguous index
/// ranges of this slice. New entries must keep each geographic group
/// contiguous, and the group bounds must be updated to match.
pub const AWS_REGIONS: &[(&str, &str)] = &[
    // Americas: indices 0..8
    ("us-east-1", "N. Virginia"),
    ("us-east-2", "Ohio"),
    ("us-west-1", "N. California"),
    ("us-west-2", "Oregon"),
    ("ca-central-1", "Canada Central"),
    ("ca-west-1", "Canada West"),
    ("mx-central-1", "Mexico Central"),
    ("sa-east-1", "Sao Paulo"),
    // Europe: indices 8..16
    ("eu-west-1", "Ireland"),
    ("eu-west-2", "London"),
    ("eu-west-3", "Paris"),
    ("eu-central-1", "Frankfurt"),
    ("eu-central-2", "Zurich"),
    ("eu-south-1", "Milan"),
    ("eu-south-2", "Spain"),
    ("eu-north-1", "Stockholm"),
    // Asia Pacific: indices 16..30
    ("ap-northeast-1", "Tokyo"),
    ("ap-northeast-2", "Seoul"),
    ("ap-northeast-3", "Osaka"),
    ("ap-southeast-1", "Singapore"),
    ("ap-southeast-2", "Sydney"),
    ("ap-southeast-3", "Jakarta"),
    ("ap-southeast-4", "Melbourne"),
    ("ap-southeast-5", "Malaysia"),
    ("ap-southeast-6", "New Zealand"),
    ("ap-southeast-7", "Thailand"),
    ("ap-east-1", "Hong Kong"),
    ("ap-east-2", "Taipei"),
    ("ap-south-1", "Mumbai"),
    ("ap-south-2", "Hyderabad"),
    // Middle East / Africa: indices 30..34
    ("me-south-1", "Bahrain"),
    ("me-central-1", "UAE"),
    ("il-central-1", "Tel Aviv"),
    ("af-south-1", "Cape Town"),
];
56
/// Geographic groupings of `AWS_REGIONS` as `(label, start, end)`.
///
/// The bounds line up with `AWS_REGIONS` as half-open index ranges
/// (`start..end`). The ranges are contiguous and together cover all 34
/// entries. Keep them in sync whenever `AWS_REGIONS` changes.
pub const AWS_REGION_GROUPS: &[(&str, usize, usize)] = &[
    ("Americas", 0, 8),
    ("Europe", 8, 16),
    ("Asia Pacific", 16, 30),
    ("Middle East / Africa", 30, 34),
];
64
/// A static AWS access-key pair used to sign EC2 requests.
///
/// There is no session-token field, and `sign_request` signs only `host` and
/// `x-amz-date`. Temporary credentials that also need a security token are
/// therefore presumably not usable. No traits are derived, in particular not
/// `Debug`, so the secret cannot leak through `{:?}` formatting.
struct AwsCredentials {
    /// Access key ID; appears in the `Credential=` part of the
    /// `Authorization` header.
    access_key: String,
    /// Secret access key; only used to derive the signing key.
    secret_key: String,
}
71
72fn resolve_credentials(token: &str, profile: &str) -> Result<AwsCredentials, ProviderError> {
73 if !profile.is_empty() {
75 return read_credentials_file(profile);
76 }
77 if let Some((ak, sk)) = token.split_once(':') {
79 if !ak.is_empty() && !sk.is_empty() {
80 return Ok(AwsCredentials {
81 access_key: ak.to_string(),
82 secret_key: sk.to_string(),
83 });
84 }
85 }
86 if let (Ok(ak), Ok(sk)) = (
88 std::env::var("AWS_ACCESS_KEY_ID"),
89 std::env::var("AWS_SECRET_ACCESS_KEY"),
90 ) {
91 if !ak.is_empty() && !sk.is_empty() {
92 return Ok(AwsCredentials {
93 access_key: ak,
94 secret_key: sk,
95 });
96 }
97 }
98 Err(ProviderError::AuthFailed)
99}
100
101fn parse_credentials(content: &str, profile: &str) -> Option<AwsCredentials> {
103 let header = format!("[{}]", profile);
104 let mut in_section = false;
105 let mut access_key = String::new();
106 let mut secret_key = String::new();
107
108 for line in content.lines() {
109 let trimmed = line.trim();
110 if trimmed.starts_with('[') {
111 in_section = trimmed == header;
112 continue;
113 }
114 if !in_section {
115 continue;
116 }
117 if let Some((key, value)) = trimmed.split_once('=') {
118 match key.trim() {
119 "aws_access_key_id" => access_key = value.trim().to_string(),
120 "aws_secret_access_key" => secret_key = value.trim().to_string(),
121 _ => {}
122 }
123 }
124 }
125
126 if access_key.is_empty() || secret_key.is_empty() {
127 None
128 } else {
129 Some(AwsCredentials {
130 access_key,
131 secret_key,
132 })
133 }
134}
135
136fn read_credentials_file(profile: &str) -> Result<AwsCredentials, ProviderError> {
137 let path = dirs::home_dir()
138 .ok_or(ProviderError::AuthFailed)?
139 .join(".aws")
140 .join("credentials");
141 let content = std::fs::read_to_string(&path).map_err(|_| ProviderError::AuthFailed)?;
142 parse_credentials(&content, profile).ok_or(ProviderError::AuthFailed)
143}
144
/// Renders `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(DIGITS[(byte >> 4) as usize] as char);
        out.push(DIGITS[(byte & 0x0f) as usize] as char);
    }
    out
}
150
151fn sha256_hash(data: &[u8]) -> Vec<u8> {
152 let mut hasher = Sha256::new();
153 hasher.update(data);
154 hasher.finalize().to_vec()
155}
156
157fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
158 let mut mac = Hmac::<Sha256>::new_from_slice(key)
162 .expect("Hmac::<Sha256>::new_from_slice accepts any key length (RFC 2104)");
163 mac.update(data);
164 mac.finalize().into_bytes().to_vec()
165}
166
/// Percent-encodes a query-string name or value before it is signed and sent.
///
/// Delegates to `super::percent_encode`. NOTE(review): SigV4 requires every
/// byte outside the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) to be
/// encoded with uppercase hex, and spaces to become `%20` rather than `+`.
/// Confirm the shared helper meets this, otherwise signatures will not match.
fn uri_encode(s: &str) -> String {
    super::percent_encode(s)
}
171
172fn format_utc(epoch_secs: u64) -> (String, String) {
174 let d = super::epoch_to_date(epoch_secs);
175 let timestamp = format!(
176 "{:04}{:02}{:02}T{:02}{:02}{:02}Z",
177 d.year, d.month, d.day, d.hours, d.minutes, d.seconds,
178 );
179 let datestamp = format!("{:04}{:02}{:02}", d.year, d.month, d.day);
180 (timestamp, datestamp)
181}
182
183fn sign_request(
185 creds: &AwsCredentials,
186 region: &str,
187 host: &str,
188 query_string: &str,
189 timestamp: &str,
190 datestamp: &str,
191) -> String {
192 let payload_hash = hex_encode(&sha256_hash(b""));
193 let canonical_headers = format!("host:{}\nx-amz-date:{}\n", host, timestamp);
194 let signed_headers = "host;x-amz-date";
195
196 let canonical_request = format!(
197 "GET\n/\n{}\n{}\n{}\n{}",
198 query_string, canonical_headers, signed_headers, payload_hash
199 );
200
201 let scope = format!("{}/{}/ec2/aws4_request", datestamp, region);
202 let string_to_sign = format!(
203 "AWS4-HMAC-SHA256\n{}\n{}\n{}",
204 timestamp,
205 scope,
206 hex_encode(&sha256_hash(canonical_request.as_bytes())),
207 );
208
209 let k_date = hmac_sha256(
210 format!("AWS4{}", creds.secret_key).as_bytes(),
211 datestamp.as_bytes(),
212 );
213 let k_region = hmac_sha256(&k_date, region.as_bytes());
214 let k_service = hmac_sha256(&k_region, b"ec2");
215 let k_signing = hmac_sha256(&k_service, b"aws4_request");
216 let signature = hex_encode(&hmac_sha256(&k_signing, string_to_sign.as_bytes()));
217
218 format!(
219 "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
220 creds.access_key, scope, signed_headers, signature
221 )
222}
223
/// Generic wrapper for EC2's XML list encoding. A collection element such as
/// `reservationSet` or `tagSet` holds zero or more `<item>` children.
///
/// Missing `<item>` children deserialize to an empty vector. The explicit
/// `bound` attribute replaces the trait bound serde would otherwise infer for
/// `T` on the generated `Deserialize` impl.
#[derive(serde::Deserialize, Debug)]
#[serde(bound(deserialize = "T: serde::Deserialize<'de>"))]
struct ItemList<T> {
    /// The `<item>` children.
    #[serde(rename = "item", default = "Vec::new")]
    item: Vec<T>,
}
233
234impl<T> Default for ItemList<T> {
235 fn default() -> Self {
236 Self { item: Vec::new() }
237 }
238}
239
/// Body of one page of an EC2 `DescribeInstances` response.
#[derive(serde::Deserialize, Debug)]
struct DescribeInstancesResponse {
    /// Reservations on this page, each carrying an `instancesSet`.
    #[serde(rename = "reservationSet", default)]
    reservation_set: ItemList<Reservation>,
    /// Pagination cursor; `describe_instances` stops paging when it is absent
    /// or empty.
    #[serde(rename = "nextToken", default)]
    next_token: Option<String>,
}
247
/// One entry of `reservationSet`; only its instance list is read.
#[derive(serde::Deserialize, Debug)]
struct Reservation {
    /// Instances listed under this reservation.
    #[serde(rename = "instancesSet", default)]
    instances_set: ItemList<Ec2Instance>,
}
253
/// One EC2 instance from a reservation's `instancesSet`.
///
/// Every field has `#[serde(default)]`, so a missing XML element yields an
/// empty value or `None` instead of a parse error.
#[derive(serde::Deserialize, Debug)]
struct Ec2Instance {
    /// Instance ID; becomes the host's `server_id` and its fallback name.
    #[serde(rename = "instanceId", default)]
    instance_id: String,
    /// AMI the instance was launched from; resolved via `DescribeImages` to
    /// fill the `os` metadata entry.
    #[serde(rename = "imageId", default)]
    image_id: String,
    /// Lifecycle state; `terminated` and `shutting-down` instances are dropped
    /// while paging.
    #[serde(rename = "instanceState", default)]
    instance_state: InstanceState,
    /// Instance type, reported as `instance` metadata when non-empty.
    #[serde(rename = "instanceType", default)]
    instance_type: String,
    /// Tags: `Name` supplies the display name, and other tags' values feed the
    /// host tag list.
    #[serde(rename = "tagSet", default)]
    tag_set: ItemList<Ec2Tag>,
    /// Public IPv4 address, if any; preferred as the host IP.
    #[serde(rename = "ipAddress", default)]
    ip_address: Option<String>,
    /// Private IPv4 address; used when there is no non-empty public one.
    #[serde(rename = "privateIpAddress", default)]
    private_ip_address: Option<String>,
}
271
/// The `instanceState` element of an instance; only the state name is read.
///
/// `Default` is derived so that an instance without this element gets an
/// empty name (`Ec2Instance::instance_state` is `#[serde(default)]`).
#[derive(serde::Deserialize, Debug, Default)]
struct InstanceState {
    /// State name such as `running` or `terminated`.
    #[serde(default)]
    name: String,
}
277
/// A single key/value entry of an instance's `tagSet`.
#[derive(serde::Deserialize, Debug)]
struct Ec2Tag {
    /// Tag key, e.g. `Name`; keys in the `aws:` namespace are ignored for host
    /// tags.
    #[serde(default)]
    key: String,
    /// Tag value; may be empty.
    #[serde(default)]
    value: String,
}
285
/// Body of an EC2 `DescribeImages` response.
#[derive(serde::Deserialize, Debug)]
struct DescribeImagesResponse {
    /// Images returned for the requested `ImageId.N` parameters.
    #[serde(rename = "imagesSet", default)]
    images_set: ItemList<ImageInfo>,
}
291
/// The subset of an AMI description this provider uses.
#[derive(serde::Deserialize, Debug)]
struct ImageInfo {
    /// AMI ID; the key of the name lookup map.
    #[serde(rename = "imageId", default)]
    image_id: String,
    /// AMI name; shown as the host's `os` metadata when non-empty.
    #[serde(default)]
    name: String,
}
299
/// Builds one owned `(name, value)` query-parameter pair (not yet encoded).
fn param(key: &str, value: &str) -> (String, String) {
    (String::from(key), String::from(value))
}
305
306fn ec2_get(
308 agent: &ureq::Agent,
309 creds: &AwsCredentials,
310 region: &str,
311 params: Vec<(String, String)>,
312) -> Result<String, ProviderError> {
313 let host = format!("ec2.{}.amazonaws.com", region);
314 let epoch = std::time::SystemTime::now()
315 .duration_since(std::time::UNIX_EPOCH)
316 .unwrap_or_default()
317 .as_secs();
318 let (timestamp, datestamp) = format_utc(epoch);
319
320 let mut sorted: Vec<(String, String)> = params
322 .into_iter()
323 .map(|(k, v)| (uri_encode(&k), uri_encode(&v)))
324 .collect();
325 sorted.sort();
326 let query_string: String = sorted
327 .iter()
328 .map(|(k, v)| format!("{}={}", k, v))
329 .collect::<Vec<_>>()
330 .join("&");
331
332 let auth = sign_request(creds, region, &host, &query_string, ×tamp, &datestamp);
333 let url = format!("https://{}/?{}", host, query_string);
334
335 let mut resp = agent
336 .get(&url)
337 .header("Authorization", &auth)
338 .header("x-amz-date", ×tamp)
339 .call()
340 .map_err(super::map_ureq_error)?;
341
342 resp.body_mut()
343 .read_to_string()
344 .map_err(|e| ProviderError::Parse(e.to_string()))
345}
346
347fn describe_instances(
349 agent: &ureq::Agent,
350 creds: &AwsCredentials,
351 region: &str,
352 cancel: &AtomicBool,
353) -> Result<Vec<Ec2Instance>, ProviderError> {
354 let mut all = Vec::new();
355 let mut next_token: Option<String> = None;
356 let mut page = 0usize;
357
358 loop {
359 page += 1;
360 if page > 500 {
361 break;
362 }
363 if cancel.load(Ordering::Relaxed) {
364 return Err(ProviderError::Cancelled);
365 }
366
367 let mut params = vec![
368 param("Action", "DescribeInstances"),
369 param("Version", "2016-11-15"),
370 ];
371 if let Some(ref token) = next_token {
372 params.push(param("NextToken", token));
373 }
374
375 let body = ec2_get(agent, creds, region, params)?;
376 let resp: DescribeInstancesResponse = quick_xml::de::from_str(&body)
377 .map_err(|e| ProviderError::Parse(format!("{}: {}", region, e)))?;
378
379 for reservation in resp.reservation_set.item {
380 for instance in reservation.instances_set.item {
381 if instance.instance_state.name != "terminated"
382 && instance.instance_state.name != "shutting-down"
383 {
384 all.push(instance);
385 }
386 }
387 }
388
389 match resp.next_token {
390 Some(t) if !t.is_empty() => next_token = Some(t),
391 _ => break,
392 }
393 }
394
395 Ok(all)
396}
397
/// Maximum number of `ImageId.N` parameters sent in one `DescribeImages`
/// request; larger AMI sets are split across several requests.
const AMI_BATCH_SIZE: usize = 100;
400
401fn fetch_image_names(
404 agent: &ureq::Agent,
405 creds: &AwsCredentials,
406 region: &str,
407 image_ids: &[String],
408) -> Result<HashMap<String, String>, ProviderError> {
409 if image_ids.is_empty() {
410 return Ok(HashMap::new());
411 }
412
413 let mut map = HashMap::new();
414 for chunk in image_ids.chunks(AMI_BATCH_SIZE) {
415 let mut params = vec![
416 param("Action", "DescribeImages"),
417 param("Version", "2016-11-15"),
418 ];
419 for (i, id) in chunk.iter().enumerate() {
420 params.push(param(&format!("ImageId.{}", i + 1), id));
421 }
422
423 let body = ec2_get(agent, creds, region, params)?;
424 let resp: DescribeImagesResponse = quick_xml::de::from_str(&body)
425 .map_err(|e| ProviderError::Parse(format!("{}: {}", region, e)))?;
426
427 for image in resp.images_set.item {
428 if !image.name.is_empty() {
429 map.insert(image.image_id, image.name);
430 }
431 }
432 }
433 Ok(map)
434}
435
436fn extract_tags(tag_set: &[Ec2Tag]) -> (String, Vec<String>) {
439 let mut name = String::new();
440 let mut tags = Vec::new();
441 for tag in tag_set {
442 if tag.key == "Name" {
443 name = tag.value.clone();
444 } else if !tag.key.starts_with("aws:") && !tag.value.is_empty() {
445 tags.push(tag.value.clone());
446 }
447 }
448 tags.sort();
449 (name, tags)
450}
451
452impl Provider for Aws {
455 fn name(&self) -> &str {
456 "aws"
457 }
458
459 fn short_label(&self) -> &str {
460 "aws"
461 }
462
463 fn fetch_hosts_cancellable(
464 &self,
465 token: &str,
466 cancel: &AtomicBool,
467 ) -> Result<Vec<ProviderHost>, ProviderError> {
468 self.fetch_hosts_with_progress(token, cancel, &|_| {})
469 }
470
471 fn fetch_hosts_with_progress(
472 &self,
473 token: &str,
474 cancel: &AtomicBool,
475 progress: &dyn Fn(&str),
476 ) -> Result<Vec<ProviderHost>, ProviderError> {
477 if self.regions.is_empty() {
478 return Err(ProviderError::Http(
479 "No AWS regions configured. Add regions in the provider settings.".to_string(),
480 ));
481 }
482
483 let valid_codes: HashSet<&str> = AWS_REGIONS.iter().map(|(c, _)| *c).collect();
484 for region in &self.regions {
485 if !valid_codes.contains(region.as_str()) {
486 return Err(ProviderError::Http(format!(
487 "Unknown AWS region '{}'. Check your provider settings.",
488 region
489 )));
490 }
491 }
492
493 let creds = resolve_credentials(token, &self.profile)?;
494 let agent = super::http_agent();
495 let total_regions = self.regions.len();
496 let mut all_hosts = Vec::new();
497 let mut failed_regions = 0usize;
498
499 for (i, region) in self.regions.iter().enumerate() {
500 if cancel.load(Ordering::Relaxed) {
501 return Err(ProviderError::Cancelled);
502 }
503
504 progress(&format!(
505 "Fetching {} ({}/{})...",
506 region,
507 i + 1,
508 total_regions
509 ));
510
511 let instances = match describe_instances(&agent, &creds, region, cancel) {
512 Ok(instances) => instances,
513 Err(ProviderError::Cancelled) => return Err(ProviderError::Cancelled),
514 Err(ProviderError::AuthFailed) => return Err(ProviderError::AuthFailed),
515 Err(ProviderError::RateLimited) => return Err(ProviderError::RateLimited),
516 Err(_) => {
517 failed_regions += 1;
518 continue;
519 }
520 };
521
522 let ami_ids: Vec<String> = {
524 let mut set = HashSet::new();
525 for inst in &instances {
526 if !inst.image_id.is_empty() {
527 set.insert(inst.image_id.clone());
528 }
529 }
530 set.into_iter().collect()
531 };
532
533 let ami_names = if !ami_ids.is_empty() {
535 progress(&format!("Resolving AMIs for {}...", region));
536 fetch_image_names(&agent, &creds, region, &ami_ids).unwrap_or_default()
537 } else {
538 HashMap::new()
539 };
540
541 for instance in instances {
542 let ip = match instance.ip_address {
543 Some(ref ip) if !ip.is_empty() => ip.clone(),
544 _ => match instance.private_ip_address {
545 Some(ref ip) if !ip.is_empty() => ip.clone(),
546 _ => continue,
547 },
548 };
549
550 let (name, tags) = extract_tags(&instance.tag_set.item);
551 let name = if name.is_empty() {
552 instance.instance_id.clone()
553 } else {
554 name
555 };
556
557 let mut metadata = Vec::new();
558 metadata.push(("region".to_string(), region.clone()));
559 if !instance.instance_type.is_empty() {
560 metadata.push(("instance".to_string(), instance.instance_type.clone()));
561 }
562 if let Some(os_name) = ami_names.get(&instance.image_id) {
563 metadata.push(("os".to_string(), os_name.clone()));
564 }
565 if !instance.instance_state.name.is_empty() {
566 metadata.push(("status".to_string(), instance.instance_state.name.clone()));
567 }
568
569 all_hosts.push(ProviderHost {
570 server_id: instance.instance_id,
571 name,
572 ip,
573 tags,
574 metadata,
575 });
576 }
577 }
578
579 let mut parts = vec![format!("{} instances", all_hosts.len())];
581 if failed_regions > 0 {
582 parts.push(format!(
583 "{} of {} regions failed",
584 failed_regions, total_regions
585 ));
586 }
587 progress(&parts.join(", "));
588
589 if failed_regions > 0 {
590 if all_hosts.is_empty() {
591 return Err(ProviderError::Http(format!(
592 "All {} regions failed. Check your credentials and region configuration.",
593 total_regions,
594 )));
595 }
596 return Err(ProviderError::PartialResult {
597 hosts: all_hosts,
598 failures: failed_regions,
599 total: total_regions,
600 });
601 }
602
603 Ok(all_hosts)
604 }
605}
606
607#[cfg(test)]
608#[path = "aws_tests.rs"]
609mod tests;