// purple_ssh/providers/gcp.rs

1use std::collections::HashSet;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use base64::Engine;
5use base64::engine::general_purpose::URL_SAFE_NO_PAD;
6use serde::Deserialize;
7
8use super::{Provider, ProviderError, ProviderHost, map_ureq_error};
9
/// Google Cloud Compute Engine provider configuration.
pub struct Gcp {
    /// Zones to keep when syncing (e.g. `us-central1-a`). Compared against the
    /// zone name of each aggregated-list scope; an empty list keeps all zones.
    pub zones: Vec<String>,
    /// Project ID whose instances are listed. Must be non-empty and contain only
    /// lowercase ASCII letters, digits, `-`, `.` and `:`.
    pub project: String,
}
14
/// Known GCP Compute Engine zones as `(zone_id, display_name)` pairs.
///
/// Single source of truth: `GCP_ZONE_GROUPS` refers to contiguous, half-open
/// index ranges of this array. Entries must therefore stay grouped by region.
/// The `(start..end)` ranges in the comments below must match
/// `GCP_ZONE_GROUPS` whenever zones are added, removed or reordered.
///
/// This list only affects the TUI zone picker. Unlisted zones are still synced
/// when no zone filter is configured (an empty filter means all zones).
pub const GCP_ZONES: &[(&str, &str)] = &[
    // US Central (0..4)
    ("us-central1-a", "Iowa A"),
    ("us-central1-b", "Iowa B"),
    ("us-central1-c", "Iowa C"),
    ("us-central1-f", "Iowa F"),
    // US East (4..13)
    ("us-east1-b", "South Carolina B"),
    ("us-east1-c", "South Carolina C"),
    ("us-east1-d", "South Carolina D"),
    ("us-east4-a", "Virginia A"),
    ("us-east4-b", "Virginia B"),
    ("us-east4-c", "Virginia C"),
    ("us-east5-a", "Columbus A"),
    ("us-east5-b", "Columbus B"),
    ("us-east5-c", "Columbus C"),
    // US South (13..16)
    ("us-south1-a", "Dallas A"),
    ("us-south1-b", "Dallas B"),
    ("us-south1-c", "Dallas C"),
    // US West (16..28)
    ("us-west1-a", "Oregon A"),
    ("us-west1-b", "Oregon B"),
    ("us-west1-c", "Oregon C"),
    ("us-west2-a", "Los Angeles A"),
    ("us-west2-b", "Los Angeles B"),
    ("us-west2-c", "Los Angeles C"),
    ("us-west3-a", "Salt Lake City A"),
    ("us-west3-b", "Salt Lake City B"),
    ("us-west3-c", "Salt Lake City C"),
    ("us-west4-a", "Las Vegas A"),
    ("us-west4-b", "Las Vegas B"),
    ("us-west4-c", "Las Vegas C"),
    // North America (28..37)
    ("northamerica-northeast1-a", "Montreal A"),
    ("northamerica-northeast1-b", "Montreal B"),
    ("northamerica-northeast1-c", "Montreal C"),
    ("northamerica-northeast2-a", "Toronto A"),
    ("northamerica-northeast2-b", "Toronto B"),
    ("northamerica-northeast2-c", "Toronto C"),
    ("northamerica-south1-a", "Queretaro A"),
    ("northamerica-south1-b", "Queretaro B"),
    ("northamerica-south1-c", "Queretaro C"),
    // South America (37..43)
    ("southamerica-east1-a", "Sao Paulo A"),
    ("southamerica-east1-b", "Sao Paulo B"),
    ("southamerica-east1-c", "Sao Paulo C"),
    ("southamerica-west1-a", "Santiago A"),
    ("southamerica-west1-b", "Santiago B"),
    ("southamerica-west1-c", "Santiago C"),
    // Europe West (43..70)
    ("europe-west1-b", "Belgium B"),
    ("europe-west1-c", "Belgium C"),
    ("europe-west1-d", "Belgium D"),
    ("europe-west2-a", "London A"),
    ("europe-west2-b", "London B"),
    ("europe-west2-c", "London C"),
    ("europe-west3-a", "Frankfurt A"),
    ("europe-west3-b", "Frankfurt B"),
    ("europe-west3-c", "Frankfurt C"),
    ("europe-west4-a", "Netherlands A"),
    ("europe-west4-b", "Netherlands B"),
    ("europe-west4-c", "Netherlands C"),
    ("europe-west6-a", "Zurich A"),
    ("europe-west6-b", "Zurich B"),
    ("europe-west6-c", "Zurich C"),
    ("europe-west8-a", "Milan A"),
    ("europe-west8-b", "Milan B"),
    ("europe-west8-c", "Milan C"),
    ("europe-west9-a", "Paris A"),
    ("europe-west9-b", "Paris B"),
    ("europe-west9-c", "Paris C"),
    ("europe-west10-a", "Berlin A"),
    ("europe-west10-b", "Berlin B"),
    ("europe-west10-c", "Berlin C"),
    ("europe-west12-a", "Turin A"),
    ("europe-west12-b", "Turin B"),
    ("europe-west12-c", "Turin C"),
    // Europe Other (70..82)
    ("europe-north1-a", "Finland A"),
    ("europe-north1-b", "Finland B"),
    ("europe-north1-c", "Finland C"),
    ("europe-north2-a", "Stockholm A"),
    ("europe-north2-b", "Stockholm B"),
    ("europe-north2-c", "Stockholm C"),
    ("europe-central2-a", "Warsaw A"),
    ("europe-central2-b", "Warsaw B"),
    ("europe-central2-c", "Warsaw C"),
    ("europe-southwest1-a", "Madrid A"),
    ("europe-southwest1-b", "Madrid B"),
    ("europe-southwest1-c", "Madrid C"),
    // Asia East (82..88)
    ("asia-east1-a", "Taiwan A"),
    ("asia-east1-b", "Taiwan B"),
    ("asia-east1-c", "Taiwan C"),
    ("asia-east2-a", "Hong Kong A"),
    ("asia-east2-b", "Hong Kong B"),
    ("asia-east2-c", "Hong Kong C"),
    // Asia Northeast (88..97)
    ("asia-northeast1-a", "Tokyo A"),
    ("asia-northeast1-b", "Tokyo B"),
    ("asia-northeast1-c", "Tokyo C"),
    ("asia-northeast2-a", "Osaka A"),
    ("asia-northeast2-b", "Osaka B"),
    ("asia-northeast2-c", "Osaka C"),
    ("asia-northeast3-a", "Seoul A"),
    ("asia-northeast3-b", "Seoul B"),
    ("asia-northeast3-c", "Seoul C"),
    // Asia South (97..103)
    ("asia-south1-a", "Mumbai A"),
    ("asia-south1-b", "Mumbai B"),
    ("asia-south1-c", "Mumbai C"),
    ("asia-south2-a", "Delhi A"),
    ("asia-south2-b", "Delhi B"),
    ("asia-south2-c", "Delhi C"),
    // Asia Southeast (103..109)
    ("asia-southeast1-a", "Singapore A"),
    ("asia-southeast1-b", "Singapore B"),
    ("asia-southeast1-c", "Singapore C"),
    ("asia-southeast2-a", "Jakarta A"),
    ("asia-southeast2-b", "Jakarta B"),
    ("asia-southeast2-c", "Jakarta C"),
    // Australia (109..115)
    ("australia-southeast1-a", "Sydney A"),
    ("australia-southeast1-b", "Sydney B"),
    ("australia-southeast1-c", "Sydney C"),
    ("australia-southeast2-a", "Melbourne A"),
    ("australia-southeast2-b", "Melbourne B"),
    ("australia-southeast2-c", "Melbourne C"),
    // Middle East (115..124)
    ("me-west1-a", "Tel Aviv A"),
    ("me-west1-b", "Tel Aviv B"),
    ("me-west1-c", "Tel Aviv C"),
    ("me-central1-a", "Doha A"),
    ("me-central1-b", "Doha B"),
    ("me-central1-c", "Doha C"),
    ("me-central2-a", "Dammam A"),
    ("me-central2-b", "Dammam B"),
    ("me-central2-c", "Dammam C"),
    // Africa (124..127)
    ("africa-south1-a", "Johannesburg A"),
    ("africa-south1-b", "Johannesburg B"),
    ("africa-south1-c", "Johannesburg C"),
];
163
/// Labelled groups over `GCP_ZONES` as `(label, start, end)`.
///
/// `start..end` is a half-open index range into `GCP_ZONES`. The ranges are
/// contiguous and in order, and the final `end` equals `GCP_ZONES.len()`.
/// Keep them in sync with the range comments in `GCP_ZONES` when editing
/// either table.
pub const GCP_ZONE_GROUPS: &[(&str, usize, usize)] = &[
    ("US Central", 0, 4),
    ("US East", 4, 13),
    ("US South", 13, 16),
    ("US West", 16, 28),
    ("North America", 28, 37),
    ("South America", 37, 43),
    ("Europe West", 43, 70),
    ("Europe Other", 70, 82),
    ("Asia East", 82, 88),
    ("Asia Northeast", 88, 97),
    ("Asia South", 97, 103),
    ("Asia Southeast", 103, 109),
    ("Australia", 109, 115),
    ("Middle East", 115, 124),
    ("Africa", 124, 127),
];
182
183// --- Serde response models ---
184
/// One page of the Compute `aggregated/instances` list response.
#[derive(Deserialize)]
struct AggregatedListResponse {
    /// Instances grouped by scope key (e.g. `"zones/us-central1-a"`).
    /// A missing `items` field deserializes as an empty map.
    #[serde(default)]
    items: std::collections::HashMap<String, InstancesScopedList>,
    /// Token for requesting the next page. The pagination loop treats an
    /// absent or empty token as "no more pages".
    #[serde(rename = "nextPageToken")]
    next_page_token: Option<String>,
}
192
/// The instances listed under one scope of the aggregated response.
#[derive(Deserialize)]
struct InstancesScopedList {
    /// A scope whose `instances` field is missing deserializes as an empty list.
    #[serde(default)]
    instances: Vec<GcpInstance>,
}
198
/// The fields of a Compute Engine instance that are mapped into a `ProviderHost`.
/// Fields marked `default` deserialize as empty when absent.
#[derive(Deserialize)]
struct GcpInstance {
    /// Instance ID; becomes `ProviderHost::server_id`.
    id: String,
    /// Instance name; becomes `ProviderHost::name`.
    name: String,
    /// Status string; emitted as `status` metadata when non-empty.
    #[serde(default)]
    status: String,
    /// Machine-type resource URL; its last path segment becomes `machine` metadata.
    #[serde(rename = "machineType", default)]
    machine_type: String,
    /// Network interfaces that `select_ip` searches for an address.
    #[serde(rename = "networkInterfaces", default)]
    network_interfaces: Vec<NetworkInterface>,
    /// Attached disks; the first disk's first license is used for `os` metadata.
    #[serde(default)]
    disks: Vec<Disk>,
    /// `tags` block; its items are copied into the host tags.
    #[serde(default)]
    tags: Option<GcpTags>,
    /// Labels; each becomes a `key` or `key:value` host tag.
    #[serde(default)]
    labels: Option<std::collections::HashMap<String, String>>,
    /// Zone resource URL; its last path segment becomes `zone` metadata.
    /// (The zone filter uses the aggregated scope key, not this field.)
    #[serde(default)]
    zone: String,
}
218
/// A network interface of an instance; the address sources for `select_ip`.
#[derive(Deserialize)]
struct NetworkInterface {
    /// IPv4 access configs carrying the external address (`natIP`).
    #[serde(rename = "accessConfigs", default)]
    access_configs: Vec<AccessConfig>,
    /// Internal IPv4 address; empty when absent.
    #[serde(rename = "networkIP", default)]
    network_ip: String,
    /// IPv6 access configs carrying the external IPv6 address.
    #[serde(rename = "ipv6AccessConfigs", default)]
    ipv6_access_configs: Vec<Ipv6AccessConfig>,
}
228
/// IPv4 access config of a network interface.
#[derive(Deserialize)]
struct AccessConfig {
    /// External IPv4 address; empty when the field is absent.
    #[serde(rename = "natIP", default)]
    nat_ip: String,
}
234
/// IPv6 access config of a network interface.
#[derive(Deserialize)]
struct Ipv6AccessConfig {
    /// External IPv6 address; empty when the field is absent.
    #[serde(rename = "externalIpv6", default)]
    external_ipv6: String,
}
240
/// Attached disk; only its license URLs are read.
#[derive(Deserialize)]
struct Disk {
    /// License resource URLs; the last path segment of the first one is used
    /// as the `os` metadata value.
    #[serde(default)]
    licenses: Vec<String>,
}
246
/// The instance `tags` block.
#[derive(Deserialize)]
struct GcpTags {
    /// Tag strings, copied verbatim into the host tags by `build_tags`.
    #[serde(default)]
    items: Vec<String>,
}
252
/// Return everything after the final `/` of a resource URL or path, e.g.
/// `".../zones/us-central1-a"` -> `"us-central1-a"`.
///
/// Input without a `/` is returned whole; a trailing `/` yields `""`.
fn last_url_segment(url: &str) -> &str {
    match url.rfind('/') {
        // '/' is a single ASCII byte, so `slash + 1` is always a char boundary.
        Some(slash) => &url[slash + 1..],
        None => url,
    }
}
257
258/// Select the best IP for an instance.
259/// Prefers external (natIP) > internal (networkIP) > external IPv6.
260fn select_ip(instance: &GcpInstance) -> Option<String> {
261    for ni in &instance.network_interfaces {
262        for ac in &ni.access_configs {
263            if !ac.nat_ip.is_empty() {
264                return Some(ac.nat_ip.clone());
265            }
266        }
267    }
268    for ni in &instance.network_interfaces {
269        if !ni.network_ip.is_empty() {
270            return Some(ni.network_ip.clone());
271        }
272    }
273    for ni in &instance.network_interfaces {
274        for v6 in &ni.ipv6_access_configs {
275            if !v6.external_ipv6.is_empty() {
276                return Some(v6.external_ipv6.clone());
277            }
278        }
279    }
280    None
281}
282
283/// Build metadata key-value pairs for an instance.
284fn build_metadata(instance: &GcpInstance) -> Vec<(String, String)> {
285    let mut metadata = super::ProviderMetadata::new();
286    let zone = last_url_segment(&instance.zone);
287    if !zone.is_empty() {
288        metadata.push("zone", zone);
289    }
290    let machine = last_url_segment(&instance.machine_type);
291    if !machine.is_empty() {
292        metadata.push("machine", machine);
293    }
294    // OS from first disk's first license (e.g. "debian-11" from license URL)
295    if let Some(disk) = instance.disks.first() {
296        if let Some(license) = disk.licenses.first() {
297            let os = last_url_segment(license);
298            if !os.is_empty() {
299                metadata.push("os", os);
300            }
301        }
302    }
303    if !instance.status.is_empty() {
304        metadata.push("status", instance.status.clone());
305    }
306    metadata.finish()
307}
308
309/// Build tags from GCP tags and labels.
310fn build_tags(instance: &GcpInstance) -> Vec<String> {
311    let mut tags = Vec::new();
312    if let Some(ref t) = instance.tags {
313        tags.extend(t.items.clone());
314    }
315    if let Some(ref labels) = instance.labels {
316        for (k, v) in labels {
317            if v.is_empty() {
318                tags.push(k.clone());
319            } else {
320                tags.push(format!("{}:{}", k, v));
321            }
322        }
323    }
324    tags
325}
326
/// Return `true` when `token` looks like a path to a service-account JSON key
/// file, i.e. it ends in `.json` compared ASCII case-insensitively.
fn is_json_key_file(token: &str) -> bool {
    const SUFFIX: &[u8] = b".json";
    let bytes = token.as_bytes();
    // Comparing raw bytes avoids allocating a lowercased copy. Non-ASCII UTF-8
    // bytes can never match the ASCII suffix, so the result is unchanged.
    match bytes.len().checked_sub(SUFFIX.len()) {
        Some(start) => bytes[start..].eq_ignore_ascii_case(SUFFIX),
        None => false,
    }
}
332
/// The subset of a Google service-account JSON key file needed to mint a
/// signed JWT. Any other fields in the file are ignored during deserialization.
#[derive(Deserialize)]
struct ServiceAccountKey {
    /// Service account email; used as the JWT `iss` claim.
    client_email: String,
    /// PEM-encoded private key, parsed as PKCS#8 and used to sign the JWT.
    private_key: String,
}
339
340/// Create a JWT and exchange it for an access token via Google's OAuth2 endpoint.
341fn resolve_service_account_token(path: &str) -> Result<String, ProviderError> {
342    let content = std::fs::read_to_string(path)
343        .map_err(|e| ProviderError::Http(format!("Failed to read key file {}: {}", path, e)))?;
344    let key: ServiceAccountKey = serde_json::from_str(&content)
345        .map_err(|e| ProviderError::Http(format!("Failed to parse key file: {}", e)))?;
346
347    let now = std::time::SystemTime::now()
348        .duration_since(std::time::UNIX_EPOCH)
349        .unwrap_or_default()
350        .as_secs();
351
352    let header = r#"{"alg":"RS256","typ":"JWT"}"#;
353    let claims = serde_json::json!({
354        "iss": key.client_email,
355        "scope": "https://www.googleapis.com/auth/compute.readonly",
356        "aud": "https://oauth2.googleapis.com/token",
357        "iat": now,
358        "exp": now + 3600
359    });
360    let claims_str = claims.to_string();
361
362    let header_b64 = URL_SAFE_NO_PAD.encode(header.as_bytes());
363    let claims_b64 = URL_SAFE_NO_PAD.encode(claims_str.as_bytes());
364    let signing_input = format!("{}.{}", header_b64, claims_b64);
365
366    // Parse the PEM private key and sign with RSA-SHA256
367    let der = rsa::pkcs8::DecodePrivateKey::from_pkcs8_pem(&key.private_key)
368        .map_err(|e| ProviderError::Http(format!("Failed to parse private key: {}", e)))?;
369    let signing_key = rsa::pkcs1v15::SigningKey::<sha2::Sha256>::new(der);
370    use rsa::signature::{SignatureEncoding, Signer};
371    let signature = signing_key.sign(signing_input.as_bytes());
372    let sig_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
373
374    let jwt = format!("{}.{}", signing_input, sig_b64);
375
376    // Exchange JWT for access token
377    let agent = super::http_agent();
378    let mut resp = agent
379        .post("https://oauth2.googleapis.com/token")
380        .send_form([
381            ("grant_type", "urn:ietf:params:oauth:grant_type:jwt-bearer"),
382            ("assertion", jwt.as_str()),
383        ])
384        .map_err(map_ureq_error)?;
385
386    #[derive(Deserialize)]
387    struct TokenResponse {
388        access_token: String,
389    }
390
391    let token_resp: TokenResponse = resp
392        .body_mut()
393        .read_json()
394        .map_err(|e| ProviderError::Parse(format!("Token response: {}", e)))?;
395
396    Ok(token_resp.access_token)
397}
398
399/// Resolve token: if it's a path to a JSON key file, exchange it for an access token.
400/// Otherwise, use it as a raw access token.
401fn resolve_token(token: &str) -> Result<String, ProviderError> {
402    if is_json_key_file(token) {
403        resolve_service_account_token(token)
404    } else {
405        Ok(token.to_string())
406    }
407}
408
/// Percent-encode a pagination token for the `pageToken` query parameter.
/// Delegates to the shared `super::percent_encode`.
fn url_encode(s: &str) -> String {
    super::percent_encode(s)
}
413
414impl Gcp {
415    /// Real Compute API host. Overridable via `fetch_with_endpoint` so tests
416    /// can drive the full list pipeline against a mock server.
417    const API_BASE: &'static str = "https://compute.googleapis.com";
418
419    /// Fetch instances against an explicit Compute API base. Production passes
420    /// `API_BASE`; tests pass a mock URL (and a plain bearer token so the
421    /// OAuth exchange is skipped), exercising URL construction, the auth
422    /// header, pagination, partial-failure handling and `ProviderHost` mapping.
423    fn fetch_with_endpoint(
424        &self,
425        api_base: &str,
426        token: &str,
427        cancel: &AtomicBool,
428        _env: &crate::runtime::env::Env,
429        progress: &dyn Fn(&str),
430    ) -> Result<Vec<ProviderHost>, ProviderError> {
431        if self.project.is_empty() {
432            return Err(ProviderError::Http(
433                "No GCP project configured. Set the Project ID in the provider settings."
434                    .to_string(),
435            ));
436        }
437
438        // Validate project ID format: lowercase letters, digits, hyphens, dots and colons
439        // (dots and colons for domain-scoped projects like example.com:my-project)
440        if !self
441            .project
442            .chars()
443            .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || matches!(c, '-' | '.' | ':'))
444        {
445            return Err(ProviderError::Http(format!(
446                "Invalid GCP project ID '{}'. Must contain only lowercase letters, digits, hyphens, dots and colons.",
447                self.project
448            )));
449        }
450
451        progress("Authenticating...");
452        let access_token = resolve_token(token)?;
453
454        if cancel.load(Ordering::Relaxed) {
455            return Err(ProviderError::Cancelled);
456        }
457
458        let zone_filter: HashSet<&str> = self.zones.iter().map(|s| s.as_str()).collect();
459        let agent = super::http_agent();
460        let mut all_hosts = Vec::new();
461        let mut page_token: Option<String> = None;
462
463        for page in 0u32.. {
464            if cancel.load(Ordering::Relaxed) {
465                return Err(ProviderError::Cancelled);
466            }
467
468            // Safety guard: prevent infinite pagination loops
469            if page > 500 {
470                break;
471            }
472
473            let mut url = format!(
474                "{}/compute/v1/projects/{}/aggregated/instances?maxResults=500&returnPartialSuccess=true",
475                api_base, self.project
476            );
477            if let Some(ref pt) = page_token {
478                url.push_str(&format!("&pageToken={}", url_encode(pt)));
479            }
480
481            progress(&format!(
482                "Fetching instances ({} so far)...",
483                all_hosts.len()
484            ));
485
486            let mut response = match agent
487                .get(&url)
488                .header("Authorization", &super::bearer_auth(&access_token))
489                .call()
490            {
491                Ok(r) => r,
492                Err(e) => {
493                    let err = map_ureq_error(e);
494                    // If we already fetched some hosts, return a partial result
495                    if !all_hosts.is_empty() {
496                        let fetched = all_hosts.len();
497                        progress(&format!("{} instances, page {} failed", fetched, page + 1));
498                        return Err(ProviderError::PartialResult {
499                            hosts: all_hosts,
500                            failures: 1,
501                            total: page as usize + 1,
502                        });
503                    }
504                    return Err(err);
505                }
506            };
507
508            let resp: AggregatedListResponse = match response.body_mut().read_json() {
509                Ok(r) => r,
510                Err(e) => {
511                    if !all_hosts.is_empty() {
512                        let fetched = all_hosts.len();
513                        progress(&format!(
514                            "{} instances, page {} failed to parse",
515                            fetched,
516                            page + 1
517                        ));
518                        return Err(ProviderError::PartialResult {
519                            hosts: all_hosts,
520                            failures: 1,
521                            total: page as usize + 1,
522                        });
523                    }
524                    return Err(ProviderError::Parse(e.to_string()));
525                }
526            };
527
528            for (scope_key, scoped_list) in &resp.items {
529                // scope_key is like "zones/us-central1-a"
530                let zone = last_url_segment(scope_key);
531
532                // Client-side zone filter (empty = all zones)
533                if !zone_filter.is_empty() && !zone_filter.contains(zone) {
534                    continue;
535                }
536
537                for instance in &scoped_list.instances {
538                    if let Some(ip) = select_ip(instance) {
539                        all_hosts.push(ProviderHost {
540                            server_id: instance.id.clone(),
541                            name: instance.name.clone(),
542                            ip,
543                            tags: build_tags(instance),
544                            metadata: build_metadata(instance),
545                        });
546                    }
547                }
548            }
549
550            match resp.next_page_token {
551                Some(ref t) if !t.is_empty() => page_token = Some(t.clone()),
552                _ => break,
553            }
554        }
555
556        progress(&format!("{} instances", all_hosts.len()));
557        Ok(all_hosts)
558    }
559}
560
561impl Provider for Gcp {
562    fn name(&self) -> &str {
563        "gcp"
564    }
565
566    fn short_label(&self) -> &str {
567        "gcp"
568    }
569
570    fn fetch_hosts_cancellable(
571        &self,
572        token: &str,
573        cancel: &AtomicBool,
574        _env: &crate::runtime::env::Env,
575    ) -> Result<Vec<ProviderHost>, ProviderError> {
576        self.fetch_hosts_with_progress(token, cancel, _env, &|_| {})
577    }
578
579    fn fetch_hosts_with_progress(
580        &self,
581        token: &str,
582        cancel: &AtomicBool,
583        env: &crate::runtime::env::Env,
584        progress: &dyn Fn(&str),
585    ) -> Result<Vec<ProviderHost>, ProviderError> {
586        self.fetch_with_endpoint(Self::API_BASE, token, cancel, env, progress)
587    }
588}
589
// Unit tests live in the sibling file `gcp_tests.rs` and are compiled only for test builds.
#[cfg(test)]
#[path = "gcp_tests.rs"]
mod tests;