Skip to main content

purple_ssh/providers/
aws.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use hmac::{Hmac, Mac};
5use sha2::{Digest, Sha256};
6
7use super::{Provider, ProviderError, ProviderHost};
8
/// Amazon EC2 host provider.
///
/// Lists the instances of every configured region. Requests are signed with a
/// key pair taken from a named profile in the AWS credentials file, from the
/// provider token, or from the environment (see `resolve_credentials`).
pub struct Aws {
    /// Region codes to query, in the order given (e.g. `us-east-1`). Each
    /// one must appear in `AWS_REGIONS`, otherwise fetching is refused.
    pub regions: Vec<String>,
    /// Profile name in the AWS credentials file; an empty string means
    /// "no profile" and the token / environment are used instead.
    pub profile: String,
}
13
/// All commonly available AWS regions with display names, as `(code, name)`.
/// Single source of truth. AWS_REGION_GROUPS references slices of this array.
///
/// Region codes configured on the provider are validated against this list
/// before any request is made (unknown codes are rejected), so a newly
/// launched region has to be added here before it can be used. The
/// `(start..end)` comments below are the index ranges used by
/// `AWS_REGION_GROUPS`; inserting or moving an entry means updating both.
pub const AWS_REGIONS: &[(&str, &str)] = &[
    // Americas (0..8)
    ("us-east-1", "N. Virginia"),
    ("us-east-2", "Ohio"),
    ("us-west-1", "N. California"),
    ("us-west-2", "Oregon"),
    ("ca-central-1", "Canada Central"),
    ("ca-west-1", "Canada West"),
    ("mx-central-1", "Mexico Central"),
    ("sa-east-1", "Sao Paulo"),
    // Europe (8..16)
    ("eu-west-1", "Ireland"),
    ("eu-west-2", "London"),
    ("eu-west-3", "Paris"),
    ("eu-central-1", "Frankfurt"),
    ("eu-central-2", "Zurich"),
    ("eu-south-1", "Milan"),
    ("eu-south-2", "Spain"),
    ("eu-north-1", "Stockholm"),
    // Asia Pacific (16..30)
    ("ap-northeast-1", "Tokyo"),
    ("ap-northeast-2", "Seoul"),
    ("ap-northeast-3", "Osaka"),
    ("ap-southeast-1", "Singapore"),
    ("ap-southeast-2", "Sydney"),
    ("ap-southeast-3", "Jakarta"),
    ("ap-southeast-4", "Melbourne"),
    ("ap-southeast-5", "Malaysia"),
    ("ap-southeast-6", "New Zealand"),
    ("ap-southeast-7", "Thailand"),
    ("ap-east-1", "Hong Kong"),
    ("ap-east-2", "Taipei"),
    ("ap-south-1", "Mumbai"),
    ("ap-south-2", "Hyderabad"),
    // Middle East / Africa (30..34)
    ("me-south-1", "Bahrain"),
    ("me-central-1", "UAE"),
    ("il-central-1", "Tel Aviv"),
    ("af-south-1", "Cape Town"),
];
56
/// Region group labels with start..end indices into AWS_REGIONS.
///
/// Each entry is `(label, start, end)` with `end` exclusive, i.e. the group
/// is `AWS_REGIONS[start..end]`. The ranges are contiguous and together
/// cover all of `AWS_REGIONS`; change both tables together.
pub const AWS_REGION_GROUPS: &[(&str, usize, usize)] = &[
    ("Americas", 0, 8),
    ("Europe", 8, 16),
    ("Asia Pacific", 16, 30),
    ("Middle East / Africa", 30, 34),
];
64
65// --- Credentials ---
66
/// A static AWS key pair used to SigV4-sign EC2 requests.
///
/// No `Debug` impl is derived, so the secret cannot end up in `{:?}` output;
/// keep it that way.
struct AwsCredentials {
    /// Access key ID; sent in clear inside the `Authorization` header.
    access_key: String,
    /// Secret access key; only ever used to derive the signing key.
    secret_key: String,
}
71
72fn resolve_credentials(
73    token: &str,
74    profile: &str,
75    env: &crate::runtime::env::Env,
76) -> Result<AwsCredentials, ProviderError> {
77    // Profile takes priority: read from ~/.aws/credentials
78    if !profile.is_empty() {
79        return read_credentials_file(profile, env);
80    }
81    // Token field: ACCESS_KEY_ID:SECRET_ACCESS_KEY
82    if let Some((ak, sk)) = token.split_once(':') {
83        if !ak.is_empty() && !sk.is_empty() {
84            return Ok(AwsCredentials {
85                access_key: ak.to_string(),
86                secret_key: sk.to_string(),
87            });
88        }
89    }
90    // Environment variables, from the injected snapshot.
91    if let Some((ak, sk)) = env.aws_credentials() {
92        if !ak.is_empty() && !sk.is_empty() {
93            return Ok(AwsCredentials {
94                access_key: ak.to_string(),
95                secret_key: sk.to_string(),
96            });
97        }
98    }
99    Err(ProviderError::AuthFailed)
100}
101
102/// Parse AWS credentials from INI content (testable without filesystem).
103fn parse_credentials(content: &str, profile: &str) -> Option<AwsCredentials> {
104    let header = format!("[{}]", profile);
105    let mut in_section = false;
106    let mut access_key = String::new();
107    let mut secret_key = String::new();
108
109    for line in content.lines() {
110        let trimmed = line.trim();
111        if trimmed.starts_with('[') {
112            in_section = trimmed == header;
113            continue;
114        }
115        if !in_section {
116            continue;
117        }
118        if let Some((key, value)) = trimmed.split_once('=') {
119            match key.trim() {
120                "aws_access_key_id" => access_key = value.trim().to_string(),
121                "aws_secret_access_key" => secret_key = value.trim().to_string(),
122                _ => {}
123            }
124        }
125    }
126
127    if access_key.is_empty() || secret_key.is_empty() {
128        None
129    } else {
130        Some(AwsCredentials {
131            access_key,
132            secret_key,
133        })
134    }
135}
136
137fn read_credentials_file(
138    profile: &str,
139    env: &crate::runtime::env::Env,
140) -> Result<AwsCredentials, ProviderError> {
141    let path = env
142        .paths()
143        .ok_or(ProviderError::AuthFailed)?
144        .aws_credentials_file();
145    let content = std::fs::read_to_string(&path).map_err(|_| ProviderError::AuthFailed)?;
146    parse_credentials(&content, profile).ok_or(ProviderError::AuthFailed)
147}
148
149// --- SigV4 signing ---
150
/// Lower-case hexadecimal rendering of `bytes`, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(DIGITS[usize::from(byte >> 4)] as char);
        out.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    out
}
154
155fn sha256_hash(data: &[u8]) -> Vec<u8> {
156    let mut hasher = Sha256::new();
157    hasher.update(data);
158    hasher.finalize().to_vec()
159}
160
161fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
162    // INVARIANT: `Hmac::<Sha256>::new_from_slice` only fails when the MAC
163    // implementation rejects the key length. HMAC-SHA256 accepts keys of any
164    // length (RFC 2104 ยง2), so this branch is unreachable for Hmac<Sha256>.
165    let mut mac = Hmac::<Sha256>::new_from_slice(key)
166        .expect("Hmac::<Sha256>::new_from_slice accepts any key length (RFC 2104)");
167    mac.update(data);
168    mac.finalize().into_bytes().to_vec()
169}
170
171/// RFC 3986 URI encoding (delegates to shared implementation).
172fn uri_encode(s: &str) -> String {
173    super::percent_encode(s)
174}
175
176/// Format epoch seconds as (timestamp, datestamp) for SigV4.
177fn format_utc(epoch_secs: u64) -> (String, String) {
178    let d = super::epoch_to_date(epoch_secs);
179    let timestamp = format!(
180        "{:04}{:02}{:02}T{:02}{:02}{:02}Z",
181        d.year, d.month, d.day, d.hours, d.minutes, d.seconds,
182    );
183    let datestamp = format!("{:04}{:02}{:02}", d.year, d.month, d.day);
184    (timestamp, datestamp)
185}
186
187/// Build the SigV4 Authorization header value.
188fn sign_request(
189    creds: &AwsCredentials,
190    region: &str,
191    host: &str,
192    query_string: &str,
193    timestamp: &str,
194    datestamp: &str,
195) -> String {
196    let payload_hash = hex_encode(&sha256_hash(b""));
197    let canonical_headers = format!("host:{}\nx-amz-date:{}\n", host, timestamp);
198    let signed_headers = "host;x-amz-date";
199
200    let canonical_request = format!(
201        "GET\n/\n{}\n{}\n{}\n{}",
202        query_string, canonical_headers, signed_headers, payload_hash
203    );
204
205    let scope = format!("{}/{}/ec2/aws4_request", datestamp, region);
206    let string_to_sign = format!(
207        "AWS4-HMAC-SHA256\n{}\n{}\n{}",
208        timestamp,
209        scope,
210        hex_encode(&sha256_hash(canonical_request.as_bytes())),
211    );
212
213    let k_date = hmac_sha256(
214        format!("AWS4{}", creds.secret_key).as_bytes(),
215        datestamp.as_bytes(),
216    );
217    let k_region = hmac_sha256(&k_date, region.as_bytes());
218    let k_service = hmac_sha256(&k_region, b"ec2");
219    let k_signing = hmac_sha256(&k_service, b"aws4_request");
220    let signature = hex_encode(&hmac_sha256(&k_signing, string_to_sign.as_bytes()));
221
222    format!(
223        "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
224        creds.access_key, scope, signed_headers, signature
225    )
226}
227
228// --- XML response structs ---
229
/// Generic wrapper for AWS XML lists that use repeated `<item>` elements.
///
/// The EC2 responses parsed here wrap collections as
/// `<someSet><item>…</item><item>…</item></someSet>`. Deserializing the set
/// element into `ItemList<T>` yields its items in document order. A set with
/// no `<item>` children (or, via the `default` on the containing field, a
/// missing set) becomes an empty list.
#[derive(serde::Deserialize, Debug)]
// Explicit bound so the derive requires only `T: Deserialize`.
#[serde(bound(deserialize = "T: serde::Deserialize<'de>"))]
struct ItemList<T> {
    #[serde(rename = "item", default = "Vec::new")]
    item: Vec<T>,
}

// Written by hand because `#[derive(Default)]` would add a `T: Default`
// bound, and the element types used with `ItemList` do not implement it.
impl<T> Default for ItemList<T> {
    fn default() -> Self {
        Self { item: Vec::new() }
    }
}
243
/// One page of a `DescribeInstances` XML response.
#[derive(serde::Deserialize, Debug)]
struct DescribeInstancesResponse {
    /// `<reservationSet>`: the reservations on this page, each holding instances.
    #[serde(rename = "reservationSet", default)]
    reservation_set: ItemList<Reservation>,
    /// `<nextToken>`: pagination cursor. An absent or empty token is treated
    /// as "last page" by `describe_instances`.
    #[serde(rename = "nextToken", default)]
    next_token: Option<String>,
}
251
/// One `<item>` of `<reservationSet>`; only its instances are used.
#[derive(serde::Deserialize, Debug)]
struct Reservation {
    /// `<instancesSet>`: the instances launched in this reservation.
    #[serde(rename = "instancesSet", default)]
    instances_set: ItemList<Ec2Instance>,
}
257
/// One `<item>` of `<instancesSet>`: the subset of instance fields this
/// provider uses. Every field falls back to its default when the element is
/// missing.
#[derive(serde::Deserialize, Debug)]
struct Ec2Instance {
    /// `<instanceId>`: used as the host's server ID and as the fallback name.
    #[serde(rename = "instanceId", default)]
    instance_id: String,
    /// `<imageId>`: the AMI the instance runs, used to look up an OS name.
    #[serde(rename = "imageId", default)]
    image_id: String,
    /// `<instanceState>`: its `name` drives filtering and the `status` metadata.
    #[serde(rename = "instanceState", default)]
    instance_state: InstanceState,
    /// `<instanceType>` (e.g. `t3.micro`): reported as `instance` metadata.
    #[serde(rename = "instanceType", default)]
    instance_type: String,
    /// `<tagSet>`: the instance's tags, including AWS-reserved `aws:*` ones.
    #[serde(rename = "tagSet", default)]
    tag_set: ItemList<Ec2Tag>,
    /// `<ipAddress>`: public IPv4 address, if any; preferred for connecting.
    #[serde(rename = "ipAddress", default)]
    ip_address: Option<String>,
    /// `<privateIpAddress>`: used when there is no public address.
    #[serde(rename = "privateIpAddress", default)]
    private_ip_address: Option<String>,
}
275
/// `<instanceState>` of an instance; only the state name is read.
#[derive(serde::Deserialize, Debug, Default)]
struct InstanceState {
    /// State name such as `running`. `describe_instances` drops instances
    /// whose state is `terminated` or `shutting-down`.
    #[serde(default)]
    name: String,
}
281
/// One `<item>` of a `<tagSet>`, read from its `<key>` and `<value>` children.
#[derive(serde::Deserialize, Debug)]
struct Ec2Tag {
    #[serde(default)]
    key: String,
    #[serde(default)]
    value: String,
}
289
/// A `DescribeImages` XML response; only the image list is read.
#[derive(serde::Deserialize, Debug)]
struct DescribeImagesResponse {
    /// `<imagesSet>`: the images returned for the requested IDs.
    #[serde(rename = "imagesSet", default)]
    images_set: ItemList<ImageInfo>,
}
295
/// One `<item>` of `<imagesSet>`: an AMI ID and its name.
#[derive(serde::Deserialize, Debug)]
struct ImageInfo {
    #[serde(rename = "imageId", default)]
    image_id: String,
    /// AMI name; shown to the user as the instance's `os` metadata.
    #[serde(default)]
    name: String,
}
303
304// --- EC2 API ---
305
/// Build one owned `(name, value)` query-parameter pair (not yet encoded).
fn param(key: &str, value: &str) -> (String, String) {
    (String::from(key), String::from(value))
}
309
310/// Make a signed GET request to the EC2 API.
311fn ec2_get(
312    agent: &ureq::Agent,
313    creds: &AwsCredentials,
314    region: &str,
315    params: Vec<(String, String)>,
316) -> Result<String, ProviderError> {
317    let host = format!("ec2.{}.amazonaws.com", region);
318    let epoch = std::time::SystemTime::now()
319        .duration_since(std::time::UNIX_EPOCH)
320        .unwrap_or_default()
321        .as_secs();
322    let (timestamp, datestamp) = format_utc(epoch);
323
324    // Build sorted, URI-encoded query string (SigV4 requires sorted params)
325    let mut sorted: Vec<(String, String)> = params
326        .into_iter()
327        .map(|(k, v)| (uri_encode(&k), uri_encode(&v)))
328        .collect();
329    sorted.sort();
330    let query_string: String = sorted
331        .iter()
332        .map(|(k, v)| format!("{}={}", k, v))
333        .collect::<Vec<_>>()
334        .join("&");
335
336    let auth = sign_request(creds, region, &host, &query_string, &timestamp, &datestamp);
337    let url = format!("https://{}/?{}", host, query_string);
338
339    let mut resp = agent
340        .get(&url)
341        .header("Authorization", &auth)
342        .header("x-amz-date", &timestamp)
343        .call()
344        .map_err(super::map_ureq_error)?;
345
346    resp.body_mut()
347        .read_to_string()
348        .map_err(|e| ProviderError::Parse(e.to_string()))
349}
350
351/// Fetch all non-terminated instances in a region (handles pagination).
352fn describe_instances(
353    agent: &ureq::Agent,
354    creds: &AwsCredentials,
355    region: &str,
356    cancel: &AtomicBool,
357) -> Result<Vec<Ec2Instance>, ProviderError> {
358    let mut all = Vec::new();
359    let mut next_token: Option<String> = None;
360    let mut page = 0usize;
361
362    loop {
363        page += 1;
364        if page > 500 {
365            break;
366        }
367        if cancel.load(Ordering::Relaxed) {
368            return Err(ProviderError::Cancelled);
369        }
370
371        let mut params = vec![
372            param("Action", "DescribeInstances"),
373            param("Version", "2016-11-15"),
374        ];
375        if let Some(ref token) = next_token {
376            params.push(param("NextToken", token));
377        }
378
379        let body = ec2_get(agent, creds, region, params)?;
380        let resp: DescribeInstancesResponse = quick_xml::de::from_str(&body)
381            .map_err(|e| ProviderError::Parse(format!("{}: {}", region, e)))?;
382
383        for reservation in resp.reservation_set.item {
384            for instance in reservation.instances_set.item {
385                if instance.instance_state.name != "terminated"
386                    && instance.instance_state.name != "shutting-down"
387                {
388                    all.push(instance);
389                }
390            }
391        }
392
393        match resp.next_token {
394            Some(t) if !t.is_empty() => next_token = Some(t),
395            _ => break,
396        }
397    }
398
399    Ok(all)
400}
401
/// Maximum AMI IDs per DescribeImages request to stay within AWS query limits.
/// The IDs travel in the request URL's query string, so this also bounds the
/// URL length.
const AMI_BATCH_SIZE: usize = 100;
404
405/// Fetch AMI ID to name mapping (best effort, returns empty map on failure).
406/// Batches requests to stay within AWS API limits.
407fn fetch_image_names(
408    agent: &ureq::Agent,
409    creds: &AwsCredentials,
410    region: &str,
411    image_ids: &[String],
412) -> Result<HashMap<String, String>, ProviderError> {
413    if image_ids.is_empty() {
414        return Ok(HashMap::new());
415    }
416
417    let mut map = HashMap::new();
418    for chunk in image_ids.chunks(AMI_BATCH_SIZE) {
419        let mut params = vec![
420            param("Action", "DescribeImages"),
421            param("Version", "2016-11-15"),
422        ];
423        for (i, id) in chunk.iter().enumerate() {
424            params.push(param(&format!("ImageId.{}", i + 1), id));
425        }
426
427        let body = ec2_get(agent, creds, region, params)?;
428        let resp: DescribeImagesResponse = quick_xml::de::from_str(&body)
429            .map_err(|e| ProviderError::Parse(format!("{}: {}", region, e)))?;
430
431        for image in resp.images_set.item {
432            if !image.name.is_empty() {
433                map.insert(image.image_id, image.name);
434            }
435        }
436    }
437    Ok(map)
438}
439
440/// Extract Name tag value and user tags from an instance's tag set.
441/// Filters out aws:* tags. Returns (name, tags) where tags are values only.
442fn extract_tags(tag_set: &[Ec2Tag]) -> (String, Vec<String>) {
443    let mut name = String::new();
444    let mut tags = Vec::new();
445    for tag in tag_set {
446        if tag.key == "Name" {
447            name = tag.value.clone();
448        } else if !tag.key.starts_with("aws:") && !tag.value.is_empty() {
449            tags.push(tag.value.clone());
450        }
451    }
452    tags.sort();
453    (name, tags)
454}
455
456// --- Provider trait ---
457
458impl Provider for Aws {
459    fn name(&self) -> &str {
460        "aws"
461    }
462
463    fn short_label(&self) -> &str {
464        "aws"
465    }
466
467    fn fetch_hosts_cancellable(
468        &self,
469        token: &str,
470        cancel: &AtomicBool,
471        env: &crate::runtime::env::Env,
472    ) -> Result<Vec<ProviderHost>, ProviderError> {
473        self.fetch_hosts_with_progress(token, cancel, env, &|_| {})
474    }
475
476    fn fetch_hosts_with_progress(
477        &self,
478        token: &str,
479        cancel: &AtomicBool,
480        env: &crate::runtime::env::Env,
481        progress: &dyn Fn(&str),
482    ) -> Result<Vec<ProviderHost>, ProviderError> {
483        if self.regions.is_empty() {
484            return Err(ProviderError::Http(
485                "No AWS regions configured. Add regions in the provider settings.".to_string(),
486            ));
487        }
488
489        let valid_codes: HashSet<&str> = AWS_REGIONS.iter().map(|(c, _)| *c).collect();
490        for region in &self.regions {
491            if !valid_codes.contains(region.as_str()) {
492                return Err(ProviderError::Http(format!(
493                    "Unknown AWS region '{}'. Check your provider settings.",
494                    region
495                )));
496            }
497        }
498
499        let creds = resolve_credentials(token, &self.profile, env)?;
500        let agent = super::http_agent();
501        let total_regions = self.regions.len();
502        let mut all_hosts = Vec::new();
503        let mut failed_regions = 0usize;
504
505        for (i, region) in self.regions.iter().enumerate() {
506            if cancel.load(Ordering::Relaxed) {
507                return Err(ProviderError::Cancelled);
508            }
509
510            progress(&format!(
511                "Fetching {} ({}/{})...",
512                region,
513                i + 1,
514                total_regions
515            ));
516
517            let instances = match describe_instances(&agent, &creds, region, cancel) {
518                Ok(instances) => instances,
519                Err(ProviderError::Cancelled) => return Err(ProviderError::Cancelled),
520                Err(ProviderError::AuthFailed) => return Err(ProviderError::AuthFailed),
521                Err(ProviderError::RateLimited) => return Err(ProviderError::RateLimited),
522                Err(_) => {
523                    failed_regions += 1;
524                    continue;
525                }
526            };
527
528            // Collect unique AMI IDs for OS metadata lookup
529            let ami_ids: Vec<String> = {
530                let mut set = HashSet::new();
531                for inst in &instances {
532                    if !inst.image_id.is_empty() {
533                        set.insert(inst.image_id.clone());
534                    }
535                }
536                set.into_iter().collect()
537            };
538
539            // Fetch AMI names (best effort)
540            let ami_names = if !ami_ids.is_empty() {
541                progress(&format!("Resolving AMIs for {}...", region));
542                fetch_image_names(&agent, &creds, region, &ami_ids).unwrap_or_default()
543            } else {
544                HashMap::new()
545            };
546
547            for instance in instances {
548                let ip = match instance.ip_address {
549                    Some(ref ip) if !ip.is_empty() => ip.clone(),
550                    _ => match instance.private_ip_address {
551                        Some(ref ip) if !ip.is_empty() => ip.clone(),
552                        _ => continue,
553                    },
554                };
555
556                let (name, tags) = extract_tags(&instance.tag_set.item);
557                let name = if name.is_empty() {
558                    instance.instance_id.clone()
559                } else {
560                    name
561                };
562
563                let mut metadata = Vec::new();
564                metadata.push(("region".to_string(), region.clone()));
565                if !instance.instance_type.is_empty() {
566                    metadata.push(("instance".to_string(), instance.instance_type.clone()));
567                }
568                if let Some(os_name) = ami_names.get(&instance.image_id) {
569                    metadata.push(("os".to_string(), os_name.clone()));
570                }
571                if !instance.instance_state.name.is_empty() {
572                    metadata.push(("status".to_string(), instance.instance_state.name.clone()));
573                }
574
575                all_hosts.push(ProviderHost {
576                    server_id: instance.instance_id,
577                    name,
578                    ip,
579                    tags,
580                    metadata,
581                });
582            }
583        }
584
585        // Summary
586        let mut parts = vec![format!("{} instances", all_hosts.len())];
587        if failed_regions > 0 {
588            parts.push(format!(
589                "{} of {} regions failed",
590                failed_regions, total_regions
591            ));
592        }
593        progress(&parts.join(", "));
594
595        if failed_regions > 0 {
596            if all_hosts.is_empty() {
597                return Err(ProviderError::Http(format!(
598                    "All {} regions failed. Check your credentials and region configuration.",
599                    total_regions,
600                )));
601            }
602            return Err(ProviderError::PartialResult {
603                hosts: all_hosts,
604                failures: failed_regions,
605                total: total_regions,
606            });
607        }
608
609        Ok(all_hosts)
610    }
611}
612
// Unit tests live in `aws_tests.rs`, next to this file.
#[cfg(test)]
#[path = "aws_tests.rs"]
mod tests;