Skip to main content

purple_ssh/providers/
aws.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use hmac::{Hmac, Mac};
5use sha2::{Digest, Sha256};
6
7use super::{Provider, ProviderError, ProviderHost};
8
/// Configuration for the AWS EC2 provider.
pub struct Aws {
    /// Region codes to query. Fetching rejects an empty list and any code that
    /// is not listed in `AWS_REGIONS`.
    pub regions: Vec<String>,
    /// Name of a section in the shared credentials file. When non-empty it is
    /// used instead of the inline token or environment credentials.
    pub profile: String,
}
13
/// All commonly available AWS regions with display names.
/// Single source of truth. AWS_REGION_GROUPS references slices of this array.
///
/// Each entry is `(region code, display name)`. The index ranges noted in the
/// section comments must stay in sync with `AWS_REGION_GROUPS` whenever an
/// entry is added, removed or moved.
pub const AWS_REGIONS: &[(&str, &str)] = &[
    // Americas (0..8)
    ("us-east-1", "N. Virginia"),
    ("us-east-2", "Ohio"),
    ("us-west-1", "N. California"),
    ("us-west-2", "Oregon"),
    ("ca-central-1", "Canada Central"),
    ("ca-west-1", "Canada West"),
    ("mx-central-1", "Mexico Central"),
    ("sa-east-1", "Sao Paulo"),
    // Europe (8..16)
    ("eu-west-1", "Ireland"),
    ("eu-west-2", "London"),
    ("eu-west-3", "Paris"),
    ("eu-central-1", "Frankfurt"),
    ("eu-central-2", "Zurich"),
    ("eu-south-1", "Milan"),
    ("eu-south-2", "Spain"),
    ("eu-north-1", "Stockholm"),
    // Asia Pacific (16..30)
    ("ap-northeast-1", "Tokyo"),
    ("ap-northeast-2", "Seoul"),
    ("ap-northeast-3", "Osaka"),
    ("ap-southeast-1", "Singapore"),
    ("ap-southeast-2", "Sydney"),
    ("ap-southeast-3", "Jakarta"),
    ("ap-southeast-4", "Melbourne"),
    ("ap-southeast-5", "Malaysia"),
    ("ap-southeast-6", "New Zealand"),
    ("ap-southeast-7", "Thailand"),
    ("ap-east-1", "Hong Kong"),
    ("ap-east-2", "Taipei"),
    ("ap-south-1", "Mumbai"),
    ("ap-south-2", "Hyderabad"),
    // Middle East / Africa (30..34)
    ("me-south-1", "Bahrain"),
    ("me-central-1", "UAE"),
    ("il-central-1", "Tel Aviv"),
    ("af-south-1", "Cape Town"),
];
56
/// Region group labels with start..end indices into AWS_REGIONS.
///
/// Each entry is `(label, start, end)` where `start..end` is a half-open
/// index range; consecutive groups are contiguous and together cover the
/// whole region table.
pub const AWS_REGION_GROUPS: &[(&str, usize, usize)] = &[
    ("Americas", 0, 8),
    ("Europe", 8, 16),
    ("Asia Pacific", 16, 30),
    ("Middle East / Africa", 30, 34),
];
64
// --- Credentials ---
66
/// AWS access key pair used to sign EC2 requests.
struct AwsCredentials {
    /// Access key ID; appears in the `Credential=` part of the Authorization header.
    access_key: String,
    /// Secret access key; seeds the SigV4 signing-key derivation and is never sent.
    secret_key: String,
}
71
72fn resolve_credentials(
73    token: &str,
74    profile: &str,
75    env: &crate::runtime::env::Env,
76) -> Result<AwsCredentials, ProviderError> {
77    // Profile takes priority: read from ~/.aws/credentials
78    if !profile.is_empty() {
79        return read_credentials_file(profile, env);
80    }
81    // Token field: ACCESS_KEY_ID:SECRET_ACCESS_KEY
82    if let Some((ak, sk)) = token.split_once(':') {
83        if !ak.is_empty() && !sk.is_empty() {
84            return Ok(AwsCredentials {
85                access_key: ak.to_string(),
86                secret_key: sk.to_string(),
87            });
88        }
89    }
90    // Environment variables, from the injected snapshot.
91    if let Some((ak, sk)) = env.aws_credentials() {
92        if !ak.is_empty() && !sk.is_empty() {
93            return Ok(AwsCredentials {
94                access_key: ak.to_string(),
95                secret_key: sk.to_string(),
96            });
97        }
98    }
99    Err(ProviderError::AuthFailed)
100}
101
102/// Parse AWS credentials from INI content (testable without filesystem).
103fn parse_credentials(content: &str, profile: &str) -> Option<AwsCredentials> {
104    let header = format!("[{}]", profile);
105    let mut in_section = false;
106    let mut access_key = String::new();
107    let mut secret_key = String::new();
108
109    for line in content.lines() {
110        let trimmed = line.trim();
111        if trimmed.starts_with('[') {
112            in_section = trimmed == header;
113            continue;
114        }
115        if !in_section {
116            continue;
117        }
118        if let Some((key, value)) = trimmed.split_once('=') {
119            match key.trim() {
120                "aws_access_key_id" => access_key = value.trim().to_string(),
121                "aws_secret_access_key" => secret_key = value.trim().to_string(),
122                _ => {}
123            }
124        }
125    }
126
127    if access_key.is_empty() || secret_key.is_empty() {
128        None
129    } else {
130        Some(AwsCredentials {
131            access_key,
132            secret_key,
133        })
134    }
135}
136
137fn read_credentials_file(
138    profile: &str,
139    env: &crate::runtime::env::Env,
140) -> Result<AwsCredentials, ProviderError> {
141    let path = env
142        .paths()
143        .ok_or(ProviderError::AuthFailed)?
144        .aws_credentials_file();
145    let content = std::fs::read_to_string(&path).map_err(|_| ProviderError::AuthFailed)?;
146    parse_credentials(&content, profile).ok_or(ProviderError::AuthFailed)
147}
148
// --- SigV4 signing ---
150
/// Render `bytes` as lowercase hexadecimal, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
154
155fn sha256_hash(data: &[u8]) -> Vec<u8> {
156    let mut hasher = Sha256::new();
157    hasher.update(data);
158    hasher.finalize().to_vec()
159}
160
161fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
162    // INVARIANT: `Hmac::<Sha256>::new_from_slice` only fails when the MAC
163    // implementation rejects the key length. HMAC-SHA256 accepts keys of any
164    // length (RFC 2104 ยง2), so this branch is unreachable for Hmac<Sha256>.
165    let mut mac = Hmac::<Sha256>::new_from_slice(key)
166        .expect("Hmac::<Sha256>::new_from_slice accepts any key length (RFC 2104)");
167    mac.update(data);
168    mac.finalize().into_bytes().to_vec()
169}
170
/// RFC 3986 URI encoding (delegates to shared implementation).
///
/// Thin wrapper over `super::percent_encode`; applied to both the keys and the
/// values of the query string before it is sorted and signed.
fn uri_encode(s: &str) -> String {
    super::percent_encode(s)
}
175
176/// Format epoch seconds as (timestamp, datestamp) for SigV4.
177fn format_utc(epoch_secs: u64) -> (String, String) {
178    let d = super::epoch_to_date(epoch_secs);
179    let timestamp = format!(
180        "{:04}{:02}{:02}T{:02}{:02}{:02}Z",
181        d.year, d.month, d.day, d.hours, d.minutes, d.seconds,
182    );
183    let datestamp = format!("{:04}{:02}{:02}", d.year, d.month, d.day);
184    (timestamp, datestamp)
185}
186
187/// Build the SigV4 Authorization header value.
188fn sign_request(
189    creds: &AwsCredentials,
190    region: &str,
191    host: &str,
192    query_string: &str,
193    timestamp: &str,
194    datestamp: &str,
195) -> String {
196    let payload_hash = hex_encode(&sha256_hash(b""));
197    let canonical_headers = format!("host:{}\nx-amz-date:{}\n", host, timestamp);
198    let signed_headers = "host;x-amz-date";
199
200    let canonical_request = format!(
201        "GET\n/\n{}\n{}\n{}\n{}",
202        query_string, canonical_headers, signed_headers, payload_hash
203    );
204
205    let scope = format!("{}/{}/ec2/aws4_request", datestamp, region);
206    let string_to_sign = format!(
207        "AWS4-HMAC-SHA256\n{}\n{}\n{}",
208        timestamp,
209        scope,
210        hex_encode(&sha256_hash(canonical_request.as_bytes())),
211    );
212
213    let k_date = hmac_sha256(
214        format!("AWS4{}", creds.secret_key).as_bytes(),
215        datestamp.as_bytes(),
216    );
217    let k_region = hmac_sha256(&k_date, region.as_bytes());
218    let k_service = hmac_sha256(&k_region, b"ec2");
219    let k_signing = hmac_sha256(&k_service, b"aws4_request");
220    let signature = hex_encode(&hmac_sha256(&k_signing, string_to_sign.as_bytes()));
221
222    format!(
223        "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
224        creds.access_key, scope, signed_headers, signature
225    )
226}
227
// --- XML response structs ---
229
/// Generic wrapper for AWS XML lists that use repeated `<item>` elements.
///
/// When no `<item>` child is present the list deserializes to an empty vector.
#[derive(serde::Deserialize, Debug)]
#[serde(bound(deserialize = "T: serde::Deserialize<'de>"))]
struct ItemList<T> {
    #[serde(rename = "item", default = "Vec::new")]
    item: Vec<T>,
}

// Written by hand rather than derived so that `ItemList<T>` is `Default`
// without requiring `T: Default`.
impl<T> Default for ItemList<T> {
    fn default() -> Self {
        Self { item: Vec::new() }
    }
}
243
/// Top-level shape of a `DescribeInstances` XML response (only the parts read here).
#[derive(serde::Deserialize, Debug)]
struct DescribeInstancesResponse {
    /// Reservations on this page; empty when the element is absent.
    #[serde(rename = "reservationSet", default)]
    reservation_set: ItemList<Reservation>,
    /// Pagination cursor; absent or empty on the last page.
    #[serde(rename = "nextToken", default)]
    next_token: Option<String>,
}
251
/// One `<item>` of `reservationSet`: a group of instances.
#[derive(serde::Deserialize, Debug)]
struct Reservation {
    /// Instances belonging to this reservation; empty when the element is absent.
    #[serde(rename = "instancesSet", default)]
    instances_set: ItemList<Ec2Instance>,
}
257
/// The subset of an EC2 instance description that this provider consumes.
/// Every field falls back to its default when missing from the XML.
#[derive(serde::Deserialize, Debug)]
struct Ec2Instance {
    /// Instance ID; used as the host's server ID and as a fallback name.
    #[serde(rename = "instanceId", default)]
    instance_id: String,
    /// AMI ID; used to look up an image name for the `os` metadata entry.
    #[serde(rename = "imageId", default)]
    image_id: String,
    /// Lifecycle state; drives filtering and the `status` metadata entry.
    #[serde(rename = "instanceState", default)]
    instance_state: InstanceState,
    /// Instance type string, reported as the `instance` metadata entry.
    #[serde(rename = "instanceType", default)]
    instance_type: String,
    /// Tags attached to the instance.
    #[serde(rename = "tagSet", default)]
    tag_set: ItemList<Ec2Tag>,
    /// Value of the `ipAddress` element; preferred over the private address.
    #[serde(rename = "ipAddress", default)]
    ip_address: Option<String>,
    /// Value of the `privateIpAddress` element; used when `ip_address` is missing or empty.
    #[serde(rename = "privateIpAddress", default)]
    private_ip_address: Option<String>,
}
275
/// Contents of the `instanceState` element.
#[derive(serde::Deserialize, Debug, Default)]
struct InstanceState {
    /// State name as sent by the API (compared against `terminated` and
    /// `shutting-down` when filtering); empty when absent.
    #[serde(default)]
    name: String,
}
281
/// One key/value tag from an instance's `tagSet`.
#[derive(serde::Deserialize, Debug)]
struct Ec2Tag {
    /// Tag key; empty when absent.
    #[serde(default)]
    key: String,
    /// Tag value; empty when absent.
    #[serde(default)]
    value: String,
}
289
/// Top-level shape of a `DescribeImages` XML response (only the parts read here).
#[derive(serde::Deserialize, Debug)]
struct DescribeImagesResponse {
    /// Images returned for the requested IDs; empty when the element is absent.
    #[serde(rename = "imagesSet", default)]
    images_set: ItemList<ImageInfo>,
}
295
/// One `<item>` of `imagesSet`.
#[derive(serde::Deserialize, Debug)]
struct ImageInfo {
    /// AMI ID; empty when absent.
    #[serde(rename = "imageId", default)]
    image_id: String,
    /// AMI name; empty when absent (such images are skipped by the name lookup).
    #[serde(default)]
    name: String,
}
303
// --- EC2 API ---
305
/// Build an owned `(key, value)` query parameter from two string slices.
fn param(key: &str, value: &str) -> (String, String) {
    (String::from(key), String::from(value))
}
309
310/// Make a signed GET request to the EC2 API.
311fn ec2_get(
312    agent: &ureq::Agent,
313    creds: &AwsCredentials,
314    region: &str,
315    endpoint: &str,
316    params: Vec<(String, String)>,
317) -> Result<String, ProviderError> {
318    // Host used for SigV4 signing and the request URL. Derived from the
319    // injected endpoint so tests can point the signed request at a mock; the
320    // authority is everything after the scheme (e.g. "ec2.us-east-1.amazonaws.com"
321    // or "127.0.0.1:1234").
322    let host = endpoint
323        .split_once("://")
324        .map(|(_, authority)| authority)
325        .unwrap_or(endpoint)
326        .to_string();
327    let epoch = std::time::SystemTime::now()
328        .duration_since(std::time::UNIX_EPOCH)
329        .unwrap_or_default()
330        .as_secs();
331    let (timestamp, datestamp) = format_utc(epoch);
332
333    // Build sorted, URI-encoded query string (SigV4 requires sorted params)
334    let mut sorted: Vec<(String, String)> = params
335        .into_iter()
336        .map(|(k, v)| (uri_encode(&k), uri_encode(&v)))
337        .collect();
338    sorted.sort();
339    let query_string: String = sorted
340        .iter()
341        .map(|(k, v)| format!("{}={}", k, v))
342        .collect::<Vec<_>>()
343        .join("&");
344
345    let auth = sign_request(creds, region, &host, &query_string, &timestamp, &datestamp);
346    let url = format!("{}/?{}", endpoint, query_string);
347
348    let mut resp = agent
349        .get(&url)
350        .header("Authorization", &auth)
351        .header("x-amz-date", &timestamp)
352        .call()
353        .map_err(super::map_ureq_error)?;
354
355    resp.body_mut()
356        .read_to_string()
357        .map_err(|e| ProviderError::Parse(e.to_string()))
358}
359
360/// Fetch all non-terminated instances in a region (handles pagination).
361fn describe_instances(
362    agent: &ureq::Agent,
363    creds: &AwsCredentials,
364    region: &str,
365    endpoint: &str,
366    cancel: &AtomicBool,
367) -> Result<Vec<Ec2Instance>, ProviderError> {
368    let mut all = Vec::new();
369    let mut next_token: Option<String> = None;
370    let mut page = 0usize;
371
372    loop {
373        page += 1;
374        if page > 500 {
375            break;
376        }
377        if cancel.load(Ordering::Relaxed) {
378            return Err(ProviderError::Cancelled);
379        }
380
381        let mut params = vec![
382            param("Action", "DescribeInstances"),
383            param("Version", "2016-11-15"),
384        ];
385        if let Some(ref token) = next_token {
386            params.push(param("NextToken", token));
387        }
388
389        let body = ec2_get(agent, creds, region, endpoint, params)?;
390        let resp: DescribeInstancesResponse = quick_xml::de::from_str(&body)
391            .map_err(|e| ProviderError::Parse(format!("{}: {}", region, e)))?;
392
393        for reservation in resp.reservation_set.item {
394            for instance in reservation.instances_set.item {
395                if instance.instance_state.name != "terminated"
396                    && instance.instance_state.name != "shutting-down"
397                {
398                    all.push(instance);
399                }
400            }
401        }
402
403        match resp.next_token {
404            Some(t) if !t.is_empty() => next_token = Some(t),
405            _ => break,
406        }
407    }
408
409    Ok(all)
410}
411
/// Maximum AMI IDs per DescribeImages request to stay within AWS query limits.
/// Larger ID lists are split into chunks of this size, one request per chunk.
const AMI_BATCH_SIZE: usize = 100;
414
415/// Fetch AMI ID to name mapping (best effort, returns empty map on failure).
416/// Batches requests to stay within AWS API limits.
417fn fetch_image_names(
418    agent: &ureq::Agent,
419    creds: &AwsCredentials,
420    region: &str,
421    endpoint: &str,
422    image_ids: &[String],
423) -> Result<HashMap<String, String>, ProviderError> {
424    if image_ids.is_empty() {
425        return Ok(HashMap::new());
426    }
427
428    let mut map = HashMap::new();
429    for chunk in image_ids.chunks(AMI_BATCH_SIZE) {
430        let mut params = vec![
431            param("Action", "DescribeImages"),
432            param("Version", "2016-11-15"),
433        ];
434        for (i, id) in chunk.iter().enumerate() {
435            params.push(param(&format!("ImageId.{}", i + 1), id));
436        }
437
438        let body = ec2_get(agent, creds, region, endpoint, params)?;
439        let resp: DescribeImagesResponse = quick_xml::de::from_str(&body)
440            .map_err(|e| ProviderError::Parse(format!("{}: {}", region, e)))?;
441
442        for image in resp.images_set.item {
443            if !image.name.is_empty() {
444                map.insert(image.image_id, image.name);
445            }
446        }
447    }
448    Ok(map)
449}
450
451/// Extract Name tag value and user tags from an instance's tag set.
452/// Filters out aws:* tags. Returns (name, tags) where tags are values only.
453fn extract_tags(tag_set: &[Ec2Tag]) -> (String, Vec<String>) {
454    let mut name = String::new();
455    let mut tags = Vec::new();
456    for tag in tag_set {
457        if tag.key == "Name" {
458            name = tag.value.clone();
459        } else if !tag.key.starts_with("aws:") && !tag.value.is_empty() {
460            tags.push(tag.value.clone());
461        }
462    }
463    tags.sort();
464    (name, tags)
465}
466
// --- Provider trait ---
468
469impl Aws {
470    /// Real EC2 endpoint for a region. Overridable via `fetch_with_endpoint`
471    /// so tests can point the signed request at a mock server.
472    fn region_endpoint(region: &str) -> String {
473        format!("https://ec2.{}.amazonaws.com", region)
474    }
475
476    /// Per-region fetch pipeline against caller-supplied endpoints. Production
477    /// resolves the real EC2 host per region; tests pass a closure returning a
478    /// mock URL so SigV4 signing, DescribeInstances + DescribeImages, XML
479    /// deserialize and `ProviderHost` mapping all run end to end.
480    fn fetch_with_endpoint(
481        &self,
482        resolve_endpoint: impl Fn(&str) -> String,
483        token: &str,
484        cancel: &AtomicBool,
485        env: &crate::runtime::env::Env,
486        progress: &dyn Fn(&str),
487    ) -> Result<Vec<ProviderHost>, ProviderError> {
488        if self.regions.is_empty() {
489            return Err(ProviderError::Http(
490                "No AWS regions configured. Add regions in the provider settings.".to_string(),
491            ));
492        }
493
494        let valid_codes: HashSet<&str> = AWS_REGIONS.iter().map(|(c, _)| *c).collect();
495        for region in &self.regions {
496            if !valid_codes.contains(region.as_str()) {
497                return Err(ProviderError::Http(format!(
498                    "Unknown AWS region '{}'. Check your provider settings.",
499                    region
500                )));
501            }
502        }
503
504        let creds = resolve_credentials(token, &self.profile, env)?;
505        let agent = super::http_agent();
506        let total_regions = self.regions.len();
507        let mut all_hosts = Vec::new();
508        let mut failed_regions = 0usize;
509
510        for (i, region) in self.regions.iter().enumerate() {
511            if cancel.load(Ordering::Relaxed) {
512                return Err(ProviderError::Cancelled);
513            }
514
515            progress(&format!(
516                "Fetching {} ({}/{})...",
517                region,
518                i + 1,
519                total_regions
520            ));
521
522            let endpoint = resolve_endpoint(region);
523            let instances = match describe_instances(&agent, &creds, region, &endpoint, cancel) {
524                Ok(instances) => instances,
525                Err(ProviderError::Cancelled) => return Err(ProviderError::Cancelled),
526                Err(ProviderError::AuthFailed) => return Err(ProviderError::AuthFailed),
527                Err(ProviderError::RateLimited) => return Err(ProviderError::RateLimited),
528                Err(_) => {
529                    failed_regions += 1;
530                    continue;
531                }
532            };
533
534            // Collect unique AMI IDs for OS metadata lookup
535            let ami_ids: Vec<String> = {
536                let mut set = HashSet::new();
537                for inst in &instances {
538                    if !inst.image_id.is_empty() {
539                        set.insert(inst.image_id.clone());
540                    }
541                }
542                set.into_iter().collect()
543            };
544
545            // Fetch AMI names (best effort)
546            let ami_names = if !ami_ids.is_empty() {
547                progress(&format!("Resolving AMIs for {}...", region));
548                fetch_image_names(&agent, &creds, region, &endpoint, &ami_ids).unwrap_or_default()
549            } else {
550                HashMap::new()
551            };
552
553            for instance in instances {
554                let ip = match instance.ip_address {
555                    Some(ref ip) if !ip.is_empty() => ip.clone(),
556                    _ => match instance.private_ip_address {
557                        Some(ref ip) if !ip.is_empty() => ip.clone(),
558                        _ => continue,
559                    },
560                };
561
562                let (name, tags) = extract_tags(&instance.tag_set.item);
563                let name = if name.is_empty() {
564                    instance.instance_id.clone()
565                } else {
566                    name
567                };
568
569                let mut metadata = super::ProviderMetadata::new();
570                metadata.push("region", region.clone());
571                if !instance.instance_type.is_empty() {
572                    metadata.push("instance", instance.instance_type.clone());
573                }
574                if let Some(os_name) = ami_names.get(&instance.image_id) {
575                    metadata.push("os", os_name.clone());
576                }
577                if !instance.instance_state.name.is_empty() {
578                    metadata.push("status", instance.instance_state.name.clone());
579                }
580
581                all_hosts.push(ProviderHost {
582                    server_id: instance.instance_id,
583                    name,
584                    ip,
585                    tags,
586                    metadata: metadata.finish(),
587                });
588            }
589        }
590
591        // Summary
592        let mut parts = vec![format!("{} instances", all_hosts.len())];
593        if failed_regions > 0 {
594            parts.push(format!(
595                "{} of {} regions failed",
596                failed_regions, total_regions
597            ));
598        }
599        progress(&parts.join(", "));
600
601        if failed_regions > 0 {
602            if all_hosts.is_empty() {
603                return Err(ProviderError::Http(format!(
604                    "All {} regions failed. Check your credentials and region configuration.",
605                    total_regions,
606                )));
607            }
608            return Err(ProviderError::PartialResult {
609                hosts: all_hosts,
610                failures: failed_regions,
611                total: total_regions,
612            });
613        }
614
615        Ok(all_hosts)
616    }
617}
618
impl Provider for Aws {
    /// Provider identifier: `"aws"`.
    fn name(&self) -> &str {
        "aws"
    }

    /// Short label for this provider: `"aws"`.
    fn short_label(&self) -> &str {
        "aws"
    }

    /// Fetch hosts without progress reporting; same as
    /// `fetch_hosts_with_progress` with a no-op callback.
    fn fetch_hosts_cancellable(
        &self,
        token: &str,
        cancel: &AtomicBool,
        env: &crate::runtime::env::Env,
    ) -> Result<Vec<ProviderHost>, ProviderError> {
        self.fetch_hosts_with_progress(token, cancel, env, &|_| {})
    }

    /// Fetch hosts from the real EC2 endpoints of all configured regions,
    /// reporting per-region progress through `progress`.
    fn fetch_hosts_with_progress(
        &self,
        token: &str,
        cancel: &AtomicBool,
        env: &crate::runtime::env::Env,
        progress: &dyn Fn(&str),
    ) -> Result<Vec<ProviderHost>, ProviderError> {
        self.fetch_with_endpoint(Self::region_endpoint, token, cancel, env, progress)
    }
}
647
// Unit tests live in the sibling file `aws_tests.rs` and are compiled only in
// test builds.
#[cfg(test)]
#[path = "aws_tests.rs"]
mod tests;