Skip to main content

purple_ssh/providers/
aws.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use hmac::{Hmac, Mac};
5use sha2::{Digest, Sha256};
6
7use super::{Provider, ProviderError, ProviderHost};
8
/// AWS EC2 host provider configuration.
pub struct Aws {
    /// Region codes to query, e.g. `"us-east-1"`. Each must appear in
    /// [`AWS_REGIONS`]; unknown codes are rejected before any request is made.
    pub regions: Vec<String>,
    /// Name of a profile section in the AWS shared credentials file. When
    /// non-empty it takes priority over the token and environment variables.
    pub profile: String,
}
13
/// All commonly available AWS regions as `(code, display name)` pairs.
///
/// Single source of truth: `AWS_REGION_GROUPS` refers to contiguous index
/// ranges of this array, and configured region codes are validated against
/// it. When inserting entries, update the group ranges (and the range
/// comments below) to match.
pub const AWS_REGIONS: &[(&str, &str)] = &[
    // Americas (0..8)
    ("us-east-1", "N. Virginia"),
    ("us-east-2", "Ohio"),
    ("us-west-1", "N. California"),
    ("us-west-2", "Oregon"),
    ("ca-central-1", "Canada Central"),
    ("ca-west-1", "Canada West"),
    ("mx-central-1", "Mexico Central"),
    ("sa-east-1", "Sao Paulo"),
    // Europe (8..16)
    ("eu-west-1", "Ireland"),
    ("eu-west-2", "London"),
    ("eu-west-3", "Paris"),
    ("eu-central-1", "Frankfurt"),
    ("eu-central-2", "Zurich"),
    ("eu-south-1", "Milan"),
    ("eu-south-2", "Spain"),
    ("eu-north-1", "Stockholm"),
    // Asia Pacific (16..30)
    ("ap-northeast-1", "Tokyo"),
    ("ap-northeast-2", "Seoul"),
    ("ap-northeast-3", "Osaka"),
    ("ap-southeast-1", "Singapore"),
    ("ap-southeast-2", "Sydney"),
    ("ap-southeast-3", "Jakarta"),
    ("ap-southeast-4", "Melbourne"),
    ("ap-southeast-5", "Malaysia"),
    ("ap-southeast-6", "New Zealand"),
    ("ap-southeast-7", "Thailand"),
    ("ap-east-1", "Hong Kong"),
    ("ap-east-2", "Taipei"),
    ("ap-south-1", "Mumbai"),
    ("ap-south-2", "Hyderabad"),
    // Middle East / Africa (30..34)
    ("me-south-1", "Bahrain"),
    ("me-central-1", "UAE"),
    ("il-central-1", "Tel Aviv"),
    ("af-south-1", "Cape Town"),
];
56
/// Region group labels with half-open `start..end` index ranges into
/// [`AWS_REGIONS`] (`end` is exclusive). The ranges are contiguous and
/// together cover the whole array, so each region belongs to exactly one group.
pub const AWS_REGION_GROUPS: &[(&str, usize, usize)] = &[
    ("Americas", 0, 8),
    ("Europe", 8, 16),
    ("Asia Pacific", 16, 30),
    ("Middle East / Africa", 30, 34),
];
64
65// --- Credentials ---
66
/// Static access-key credentials used to sign EC2 requests.
///
/// Only a key ID and secret are carried. There is no session-token field, so
/// temporary (STS) credentials cannot be used by the signer as written.
/// The type has no `Debug` derive, so `secret_key` cannot leak through `{:?}`
/// formatting; keep it that way (or redact the secret) if one is ever added.
struct AwsCredentials {
    access_key: String,
    secret_key: String,
}
71
72fn resolve_credentials(token: &str, profile: &str) -> Result<AwsCredentials, ProviderError> {
73    // Profile takes priority: read from ~/.aws/credentials
74    if !profile.is_empty() {
75        return read_credentials_file(profile);
76    }
77    // Token field: ACCESS_KEY_ID:SECRET_ACCESS_KEY
78    if let Some((ak, sk)) = token.split_once(':') {
79        if !ak.is_empty() && !sk.is_empty() {
80            return Ok(AwsCredentials {
81                access_key: ak.to_string(),
82                secret_key: sk.to_string(),
83            });
84        }
85    }
86    // Environment variables
87    if let (Ok(ak), Ok(sk)) = (
88        std::env::var("AWS_ACCESS_KEY_ID"),
89        std::env::var("AWS_SECRET_ACCESS_KEY"),
90    ) {
91        if !ak.is_empty() && !sk.is_empty() {
92            return Ok(AwsCredentials {
93                access_key: ak,
94                secret_key: sk,
95            });
96        }
97    }
98    Err(ProviderError::AuthFailed)
99}
100
101/// Parse AWS credentials from INI content (testable without filesystem).
102fn parse_credentials(content: &str, profile: &str) -> Option<AwsCredentials> {
103    let header = format!("[{}]", profile);
104    let mut in_section = false;
105    let mut access_key = String::new();
106    let mut secret_key = String::new();
107
108    for line in content.lines() {
109        let trimmed = line.trim();
110        if trimmed.starts_with('[') {
111            in_section = trimmed == header;
112            continue;
113        }
114        if !in_section {
115            continue;
116        }
117        if let Some((key, value)) = trimmed.split_once('=') {
118            match key.trim() {
119                "aws_access_key_id" => access_key = value.trim().to_string(),
120                "aws_secret_access_key" => secret_key = value.trim().to_string(),
121                _ => {}
122            }
123        }
124    }
125
126    if access_key.is_empty() || secret_key.is_empty() {
127        None
128    } else {
129        Some(AwsCredentials {
130            access_key,
131            secret_key,
132        })
133    }
134}
135
136fn read_credentials_file(profile: &str) -> Result<AwsCredentials, ProviderError> {
137    let path = dirs::home_dir()
138        .ok_or(ProviderError::AuthFailed)?
139        .join(".aws")
140        .join("credentials");
141    let content = std::fs::read_to_string(&path).map_err(|_| ProviderError::AuthFailed)?;
142    parse_credentials(&content, profile).ok_or(ProviderError::AuthFailed)
143}
144
145// --- SigV4 signing ---
146
/// Lowercase hexadecimal encoding of `bytes`, two characters per byte, as
/// used for SigV4 hashes and signatures.
fn hex_encode(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    // Pre-size the output and push digits directly instead of allocating a
    // temporary `String` per byte via `format!`.
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
150
151fn sha256_hash(data: &[u8]) -> Vec<u8> {
152    let mut hasher = Sha256::new();
153    hasher.update(data);
154    hasher.finalize().to_vec()
155}
156
/// HMAC-SHA256 of `data` keyed with `key`, returned as the 32-byte tag.
/// Used for each step of the SigV4 signing-key derivation and the final signature.
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
    // INVARIANT: `Hmac::<Sha256>::new_from_slice` only fails when the MAC
    // implementation rejects the key length. HMAC-SHA256 accepts keys of any
    // length (RFC 2104 §2), so this branch is unreachable for Hmac<Sha256>.
    let mut mac = Hmac::<Sha256>::new_from_slice(key)
        .expect("Hmac::<Sha256>::new_from_slice accepts any key length (RFC 2104)");
    mac.update(data);
    mac.finalize().into_bytes().to_vec()
}
166
/// RFC 3986 URI encoding (delegates to shared implementation).
///
/// Applied to every query-parameter name and value before they are sorted and
/// signed. NOTE(review): SigV4 requires that only `A-Z a-z 0-9 - _ . ~` stay
/// unescaped (spaces as `%20`, not `+`). This is assumed to be what
/// `super::percent_encode` does; confirm there.
fn uri_encode(s: &str) -> String {
    super::percent_encode(s)
}
171
172/// Format epoch seconds as (timestamp, datestamp) for SigV4.
173fn format_utc(epoch_secs: u64) -> (String, String) {
174    let d = super::epoch_to_date(epoch_secs);
175    let timestamp = format!(
176        "{:04}{:02}{:02}T{:02}{:02}{:02}Z",
177        d.year, d.month, d.day, d.hours, d.minutes, d.seconds,
178    );
179    let datestamp = format!("{:04}{:02}{:02}", d.year, d.month, d.day);
180    (timestamp, datestamp)
181}
182
183/// Build the SigV4 Authorization header value.
184fn sign_request(
185    creds: &AwsCredentials,
186    region: &str,
187    host: &str,
188    query_string: &str,
189    timestamp: &str,
190    datestamp: &str,
191) -> String {
192    let payload_hash = hex_encode(&sha256_hash(b""));
193    let canonical_headers = format!("host:{}\nx-amz-date:{}\n", host, timestamp);
194    let signed_headers = "host;x-amz-date";
195
196    let canonical_request = format!(
197        "GET\n/\n{}\n{}\n{}\n{}",
198        query_string, canonical_headers, signed_headers, payload_hash
199    );
200
201    let scope = format!("{}/{}/ec2/aws4_request", datestamp, region);
202    let string_to_sign = format!(
203        "AWS4-HMAC-SHA256\n{}\n{}\n{}",
204        timestamp,
205        scope,
206        hex_encode(&sha256_hash(canonical_request.as_bytes())),
207    );
208
209    let k_date = hmac_sha256(
210        format!("AWS4{}", creds.secret_key).as_bytes(),
211        datestamp.as_bytes(),
212    );
213    let k_region = hmac_sha256(&k_date, region.as_bytes());
214    let k_service = hmac_sha256(&k_region, b"ec2");
215    let k_signing = hmac_sha256(&k_service, b"aws4_request");
216    let signature = hex_encode(&hmac_sha256(&k_signing, string_to_sign.as_bytes()));
217
218    format!(
219        "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
220        creds.access_key, scope, signed_headers, signature
221    )
222}
223
224// --- XML response structs ---
225
/// Generic wrapper for AWS XML lists that use repeated `<item>` elements.
///
/// If the list element has no `<item>` children, `item` is an empty vector
/// (`default = "Vec::new"`).
#[derive(serde::Deserialize, Debug)]
#[serde(bound(deserialize = "T: serde::Deserialize<'de>"))]
struct ItemList<T> {
    /// One entry per `<item>` child element.
    #[serde(rename = "item", default = "Vec::new")]
    item: Vec<T>,
}
233
// Hand-written rather than derived: `#[derive(Default)]` would add a
// `T: Default` bound, which element types such as `Ec2Instance` do not
// implement. Parent structs use this through `#[serde(default)]` when a whole
// list element is absent from the XML.
impl<T> Default for ItemList<T> {
    fn default() -> Self {
        Self { item: Vec::new() }
    }
}
239
/// Top-level body of an EC2 `DescribeInstances` response (one page).
#[derive(serde::Deserialize, Debug)]
struct DescribeInstancesResponse {
    /// `<reservationSet>`; empty when absent.
    #[serde(rename = "reservationSet", default)]
    reservation_set: ItemList<Reservation>,
    /// Pagination cursor; `None` or an empty string means this was the last page.
    #[serde(rename = "nextToken", default)]
    next_token: Option<String>,
}
247
/// One `<item>` of a `<reservationSet>`; its instances are in `<instancesSet>`.
#[derive(serde::Deserialize, Debug)]
struct Reservation {
    #[serde(rename = "instancesSet", default)]
    instances_set: ItemList<Ec2Instance>,
}
253
/// One `<item>` of an `<instancesSet>`. Every field is `#[serde(default)]`, so
/// missing elements deserialize to empty strings / `None` instead of failing.
#[derive(serde::Deserialize, Debug)]
struct Ec2Instance {
    /// Instance ID; becomes the host's `server_id` and its fallback name.
    #[serde(rename = "instanceId", default)]
    instance_id: String,
    /// AMI ID, looked up via `DescribeImages` to fill the `os` metadata.
    #[serde(rename = "imageId", default)]
    image_id: String,
    /// Lifecycle state; `terminated` and `shutting-down` instances are skipped.
    #[serde(rename = "instanceState", default)]
    instance_state: InstanceState,
    /// Instance type, reported as `instance` metadata when non-empty.
    #[serde(rename = "instanceType", default)]
    instance_type: String,
    /// All tags, including AWS-managed `aws:*` ones (filtered by `extract_tags`).
    #[serde(rename = "tagSet", default)]
    tag_set: ItemList<Ec2Tag>,
    /// Public IPv4 address; preferred as the host IP when non-empty.
    #[serde(rename = "ipAddress", default)]
    ip_address: Option<String>,
    /// Private IPv4 address; used when there is no public address.
    #[serde(rename = "privateIpAddress", default)]
    private_ip_address: Option<String>,
}
271
/// `<instanceState>` of an instance; only its `<name>` is read.
#[derive(serde::Deserialize, Debug, Default)]
struct InstanceState {
    /// State name such as `terminated` or `shutting-down`; also reported as
    /// `status` metadata when non-empty.
    #[serde(default)]
    name: String,
}
277
/// One `<item>` of a `<tagSet>`: a tag key and its (possibly empty) value.
#[derive(serde::Deserialize, Debug)]
struct Ec2Tag {
    #[serde(default)]
    key: String,
    #[serde(default)]
    value: String,
}
285
/// Top-level body of an EC2 `DescribeImages` response.
#[derive(serde::Deserialize, Debug)]
struct DescribeImagesResponse {
    /// `<imagesSet>`; empty when absent.
    #[serde(rename = "imagesSet", default)]
    images_set: ItemList<ImageInfo>,
}
291
/// One `<item>` of an `<imagesSet>`: an AMI ID and its name.
#[derive(serde::Deserialize, Debug)]
struct ImageInfo {
    #[serde(rename = "imageId", default)]
    image_id: String,
    /// AMI name, shown as the host's `os` metadata; empty names are ignored.
    #[serde(default)]
    name: String,
}
299
300// --- EC2 API ---
301
/// Build one owned, not-yet-encoded `(name, value)` query-parameter pair.
fn param(key: &str, value: &str) -> (String, String) {
    (String::from(key), String::from(value))
}
305
306/// Make a signed GET request to the EC2 API.
307fn ec2_get(
308    agent: &ureq::Agent,
309    creds: &AwsCredentials,
310    region: &str,
311    params: Vec<(String, String)>,
312) -> Result<String, ProviderError> {
313    let host = format!("ec2.{}.amazonaws.com", region);
314    let epoch = std::time::SystemTime::now()
315        .duration_since(std::time::UNIX_EPOCH)
316        .unwrap_or_default()
317        .as_secs();
318    let (timestamp, datestamp) = format_utc(epoch);
319
320    // Build sorted, URI-encoded query string (SigV4 requires sorted params)
321    let mut sorted: Vec<(String, String)> = params
322        .into_iter()
323        .map(|(k, v)| (uri_encode(&k), uri_encode(&v)))
324        .collect();
325    sorted.sort();
326    let query_string: String = sorted
327        .iter()
328        .map(|(k, v)| format!("{}={}", k, v))
329        .collect::<Vec<_>>()
330        .join("&");
331
332    let auth = sign_request(creds, region, &host, &query_string, &timestamp, &datestamp);
333    let url = format!("https://{}/?{}", host, query_string);
334
335    let mut resp = agent
336        .get(&url)
337        .header("Authorization", &auth)
338        .header("x-amz-date", &timestamp)
339        .call()
340        .map_err(super::map_ureq_error)?;
341
342    resp.body_mut()
343        .read_to_string()
344        .map_err(|e| ProviderError::Parse(e.to_string()))
345}
346
347/// Fetch all non-terminated instances in a region (handles pagination).
348fn describe_instances(
349    agent: &ureq::Agent,
350    creds: &AwsCredentials,
351    region: &str,
352    cancel: &AtomicBool,
353) -> Result<Vec<Ec2Instance>, ProviderError> {
354    let mut all = Vec::new();
355    let mut next_token: Option<String> = None;
356    let mut page = 0usize;
357
358    loop {
359        page += 1;
360        if page > 500 {
361            break;
362        }
363        if cancel.load(Ordering::Relaxed) {
364            return Err(ProviderError::Cancelled);
365        }
366
367        let mut params = vec![
368            param("Action", "DescribeInstances"),
369            param("Version", "2016-11-15"),
370        ];
371        if let Some(ref token) = next_token {
372            params.push(param("NextToken", token));
373        }
374
375        let body = ec2_get(agent, creds, region, params)?;
376        let resp: DescribeInstancesResponse = quick_xml::de::from_str(&body)
377            .map_err(|e| ProviderError::Parse(format!("{}: {}", region, e)))?;
378
379        for reservation in resp.reservation_set.item {
380            for instance in reservation.instances_set.item {
381                if instance.instance_state.name != "terminated"
382                    && instance.instance_state.name != "shutting-down"
383                {
384                    all.push(instance);
385                }
386            }
387        }
388
389        match resp.next_token {
390            Some(t) if !t.is_empty() => next_token = Some(t),
391            _ => break,
392        }
393    }
394
395    Ok(all)
396}
397
/// Maximum AMI IDs per DescribeImages request to stay within AWS query limits.
/// Each ID becomes one `ImageId.N` parameter in the signed GET query string,
/// so this also bounds the request URL length.
const AMI_BATCH_SIZE: usize = 100;
400
401/// Fetch AMI ID to name mapping (best effort, returns empty map on failure).
402/// Batches requests to stay within AWS API limits.
403fn fetch_image_names(
404    agent: &ureq::Agent,
405    creds: &AwsCredentials,
406    region: &str,
407    image_ids: &[String],
408) -> Result<HashMap<String, String>, ProviderError> {
409    if image_ids.is_empty() {
410        return Ok(HashMap::new());
411    }
412
413    let mut map = HashMap::new();
414    for chunk in image_ids.chunks(AMI_BATCH_SIZE) {
415        let mut params = vec![
416            param("Action", "DescribeImages"),
417            param("Version", "2016-11-15"),
418        ];
419        for (i, id) in chunk.iter().enumerate() {
420            params.push(param(&format!("ImageId.{}", i + 1), id));
421        }
422
423        let body = ec2_get(agent, creds, region, params)?;
424        let resp: DescribeImagesResponse = quick_xml::de::from_str(&body)
425            .map_err(|e| ProviderError::Parse(format!("{}: {}", region, e)))?;
426
427        for image in resp.images_set.item {
428            if !image.name.is_empty() {
429                map.insert(image.image_id, image.name);
430            }
431        }
432    }
433    Ok(map)
434}
435
436/// Extract Name tag value and user tags from an instance's tag set.
437/// Filters out aws:* tags. Returns (name, tags) where tags are values only.
438fn extract_tags(tag_set: &[Ec2Tag]) -> (String, Vec<String>) {
439    let mut name = String::new();
440    let mut tags = Vec::new();
441    for tag in tag_set {
442        if tag.key == "Name" {
443            name = tag.value.clone();
444        } else if !tag.key.starts_with("aws:") && !tag.value.is_empty() {
445            tags.push(tag.value.clone());
446        }
447    }
448    tags.sort();
449    (name, tags)
450}
451
452// --- Provider trait ---
453
454impl Provider for Aws {
455    fn name(&self) -> &str {
456        "aws"
457    }
458
459    fn short_label(&self) -> &str {
460        "aws"
461    }
462
463    fn fetch_hosts_cancellable(
464        &self,
465        token: &str,
466        cancel: &AtomicBool,
467    ) -> Result<Vec<ProviderHost>, ProviderError> {
468        self.fetch_hosts_with_progress(token, cancel, &|_| {})
469    }
470
471    fn fetch_hosts_with_progress(
472        &self,
473        token: &str,
474        cancel: &AtomicBool,
475        progress: &dyn Fn(&str),
476    ) -> Result<Vec<ProviderHost>, ProviderError> {
477        if self.regions.is_empty() {
478            return Err(ProviderError::Http(
479                "No AWS regions configured. Add regions in the provider settings.".to_string(),
480            ));
481        }
482
483        let valid_codes: HashSet<&str> = AWS_REGIONS.iter().map(|(c, _)| *c).collect();
484        for region in &self.regions {
485            if !valid_codes.contains(region.as_str()) {
486                return Err(ProviderError::Http(format!(
487                    "Unknown AWS region '{}'. Check your provider settings.",
488                    region
489                )));
490            }
491        }
492
493        let creds = resolve_credentials(token, &self.profile)?;
494        let agent = super::http_agent();
495        let total_regions = self.regions.len();
496        let mut all_hosts = Vec::new();
497        let mut failed_regions = 0usize;
498
499        for (i, region) in self.regions.iter().enumerate() {
500            if cancel.load(Ordering::Relaxed) {
501                return Err(ProviderError::Cancelled);
502            }
503
504            progress(&format!(
505                "Fetching {} ({}/{})...",
506                region,
507                i + 1,
508                total_regions
509            ));
510
511            let instances = match describe_instances(&agent, &creds, region, cancel) {
512                Ok(instances) => instances,
513                Err(ProviderError::Cancelled) => return Err(ProviderError::Cancelled),
514                Err(ProviderError::AuthFailed) => return Err(ProviderError::AuthFailed),
515                Err(ProviderError::RateLimited) => return Err(ProviderError::RateLimited),
516                Err(_) => {
517                    failed_regions += 1;
518                    continue;
519                }
520            };
521
522            // Collect unique AMI IDs for OS metadata lookup
523            let ami_ids: Vec<String> = {
524                let mut set = HashSet::new();
525                for inst in &instances {
526                    if !inst.image_id.is_empty() {
527                        set.insert(inst.image_id.clone());
528                    }
529                }
530                set.into_iter().collect()
531            };
532
533            // Fetch AMI names (best effort)
534            let ami_names = if !ami_ids.is_empty() {
535                progress(&format!("Resolving AMIs for {}...", region));
536                fetch_image_names(&agent, &creds, region, &ami_ids).unwrap_or_default()
537            } else {
538                HashMap::new()
539            };
540
541            for instance in instances {
542                let ip = match instance.ip_address {
543                    Some(ref ip) if !ip.is_empty() => ip.clone(),
544                    _ => match instance.private_ip_address {
545                        Some(ref ip) if !ip.is_empty() => ip.clone(),
546                        _ => continue,
547                    },
548                };
549
550                let (name, tags) = extract_tags(&instance.tag_set.item);
551                let name = if name.is_empty() {
552                    instance.instance_id.clone()
553                } else {
554                    name
555                };
556
557                let mut metadata = Vec::new();
558                metadata.push(("region".to_string(), region.clone()));
559                if !instance.instance_type.is_empty() {
560                    metadata.push(("instance".to_string(), instance.instance_type.clone()));
561                }
562                if let Some(os_name) = ami_names.get(&instance.image_id) {
563                    metadata.push(("os".to_string(), os_name.clone()));
564                }
565                if !instance.instance_state.name.is_empty() {
566                    metadata.push(("status".to_string(), instance.instance_state.name.clone()));
567                }
568
569                all_hosts.push(ProviderHost {
570                    server_id: instance.instance_id,
571                    name,
572                    ip,
573                    tags,
574                    metadata,
575                });
576            }
577        }
578
579        // Summary
580        let mut parts = vec![format!("{} instances", all_hosts.len())];
581        if failed_regions > 0 {
582            parts.push(format!(
583                "{} of {} regions failed",
584                failed_regions, total_regions
585            ));
586        }
587        progress(&parts.join(", "));
588
589        if failed_regions > 0 {
590            if all_hosts.is_empty() {
591                return Err(ProviderError::Http(format!(
592                    "All {} regions failed. Check your credentials and region configuration.",
593                    total_regions,
594                )));
595            }
596            return Err(ProviderError::PartialResult {
597                hosts: all_hosts,
598                failures: failed_regions,
599                total: total_regions,
600            });
601        }
602
603        Ok(all_hosts)
604    }
605}
606
// Unit tests live in the sibling file `aws_tests.rs` and are compiled only
// under `cfg(test)`.
#[cfg(test)]
#[path = "aws_tests.rs"]
mod tests;