Skip to main content

ati/core/
gcs.rs

1//! GCS (Google Cloud Storage) client for the ATI skill registry.
2//!
3//! Uses the GCS JSON API directly via `reqwest` — no additional crate dependencies.
4//! Authentication is via service account JSON (the same `gcp_credentials` key
5//! already stored in the ATI keyring).
6
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::Mutex;
10use thiserror::Error;
11
12use crate::core::skill::{self, SkillMeta};
13
/// Errors produced by the GCS client and the skill loader in this module.
#[derive(Error, Debug)]
pub enum GcsError {
    /// Transport-level failure reported by `reqwest` (also used for failures
    /// while building the HTTP client or decoding a response body).
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// Credential preparation failed, e.g. the private key is not a valid RSA PEM.
    #[error("Auth error: {0}")]
    Auth(String),
    /// A `serde_json` error converted via `?`.
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    /// Signing the service-account JWT assertion failed.
    #[error("JWT signing error: {0}")]
    Jwt(#[from] jsonwebtoken::errors::Error),
    /// The token endpoint or the GCS API answered with a non-success status;
    /// `message` carries the response body (empty if it could not be read).
    #[error("GCS API error ({status}): {message}")]
    Api { status: u16, message: String },
    /// An object requested as text did not contain valid UTF-8.
    #[error("Invalid UTF-8 in GCS object: {0}")]
    Utf8(String),
    /// The service account JSON failed to parse, or `client_email` /
    /// `private_key` was empty.
    #[error("Invalid service account JSON: {0}")]
    InvalidCredentials(String),
}
31
32// ---------------------------------------------------------------------------
33// Service account credentials
34// ---------------------------------------------------------------------------
35
/// The subset of a service account key JSON that this client reads.
#[derive(Deserialize)]
struct ServiceAccount {
    /// Service account identity; sent as the `iss` claim of the JWT assertion.
    client_email: String,
    /// PEM-encoded RSA private key used to sign the JWT assertion.
    private_key: String,
    /// OAuth2 token endpoint; falls back to `default_token_uri()` when the
    /// JSON omits the field.
    #[serde(default = "default_token_uri")]
    token_uri: String,
}
43
/// Serde fallback for `ServiceAccount::token_uri`: Google's OAuth2 token endpoint.
fn default_token_uri() -> String {
    String::from("https://oauth2.googleapis.com/token")
}
47
/// An OAuth2 access token together with the time it stops being reused.
struct CachedToken {
    /// Bearer token value returned by the token endpoint.
    token: String,
    /// Unix timestamp (seconds) from which the token is treated as expired.
    /// `access_token` stores this with a 5-minute safety buffer already subtracted.
    expires_at: u64,
}
52
53// ---------------------------------------------------------------------------
54// GCS client
55// ---------------------------------------------------------------------------
56
/// Minimal GCS client using the JSON API. Authenticates via service account JWT.
///
/// Construct it with `GcsClient::new` (read-only scope) or
/// `GcsClient::new_read_write` (read/write scope).
pub struct GcsClient {
    /// Name of the bucket every request targets.
    bucket: String,
    /// HTTP client reused for the token exchange and all GCS requests.
    http: reqwest::Client,
    /// Parsed service account credentials used to mint access tokens.
    service_account: ServiceAccount,
    /// Most recently minted access token, if any; guarded by a mutex so the
    /// client can be used through a shared reference.
    token: Mutex<Option<CachedToken>>,
    /// OAuth2 scope requested when minting tokens.
    scope: String,
}
65
66impl GcsClient {
67    /// Create a new read-only GCS client (used by the skill registry loader).
68    pub fn new(bucket: String, service_account_json: &str) -> Result<Self, GcsError> {
69        Self::new_with_scope(
70            bucket,
71            service_account_json,
72            "https://www.googleapis.com/auth/devstorage.read_only",
73        )
74    }
75
76    /// Create a read/write GCS client (used by the file_manager upload tool).
77    pub fn new_read_write(bucket: String, service_account_json: &str) -> Result<Self, GcsError> {
78        Self::new_with_scope(
79            bucket,
80            service_account_json,
81            "https://www.googleapis.com/auth/devstorage.read_write",
82        )
83    }
84
85    fn new_with_scope(
86        bucket: String,
87        service_account_json: &str,
88        scope: &str,
89    ) -> Result<Self, GcsError> {
90        let sa: ServiceAccount = serde_json::from_str(service_account_json)
91            .map_err(|e| GcsError::InvalidCredentials(e.to_string()))?;
92
93        if sa.client_email.is_empty() || sa.private_key.is_empty() {
94            return Err(GcsError::InvalidCredentials(
95                "client_email and private_key are required".into(),
96            ));
97        }
98
99        let http = reqwest::Client::builder()
100            .timeout(std::time::Duration::from_secs(60))
101            .build()
102            .map_err(GcsError::Http)?;
103
104        Ok(Self {
105            bucket,
106            http,
107            service_account: sa,
108            token: Mutex::new(None),
109            scope: scope.to_string(),
110        })
111    }
112
113    /// Bucket name this client targets.
114    pub fn bucket(&self) -> &str {
115        &self.bucket
116    }
117
118    /// Get a valid access token, refreshing if expired.
119    async fn access_token(&self) -> Result<String, GcsError> {
120        // Check cached token
121        {
122            let guard = self.token.lock().unwrap();
123            if let Some(ref cached) = *guard {
124                let now = std::time::SystemTime::now()
125                    .duration_since(std::time::UNIX_EPOCH)
126                    .unwrap()
127                    .as_secs();
128                if now < cached.expires_at {
129                    return Ok(cached.token.clone());
130                }
131            }
132        }
133
134        // Mint a new token via service account JWT → OAuth2 exchange
135        let now = std::time::SystemTime::now()
136            .duration_since(std::time::UNIX_EPOCH)
137            .unwrap()
138            .as_secs();
139
140        let claims = serde_json::json!({
141            "iss": self.service_account.client_email,
142            "scope": self.scope,
143            "aud": self.service_account.token_uri,
144            "iat": now,
145            "exp": now + 3600,
146        });
147
148        let key =
149            jsonwebtoken::EncodingKey::from_rsa_pem(self.service_account.private_key.as_bytes())
150                .map_err(|e| GcsError::Auth(format!("invalid RSA key: {e}")))?;
151
152        let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
153        let assertion = jsonwebtoken::encode(&header, &claims, &key)?;
154
155        // Exchange JWT for access token
156        let resp = self
157            .http
158            .post(&self.service_account.token_uri)
159            .form(&[
160                ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
161                ("assertion", &assertion),
162            ])
163            .send()
164            .await?;
165
166        if !resp.status().is_success() {
167            let status = resp.status().as_u16();
168            let body = resp.text().await.unwrap_or_default();
169            return Err(GcsError::Api {
170                status,
171                message: body,
172            });
173        }
174
175        #[derive(Deserialize)]
176        struct TokenResponse {
177            access_token: String,
178            expires_in: Option<u64>,
179        }
180
181        let token_resp: TokenResponse = resp.json().await?;
182        let expires_at = now + token_resp.expires_in.unwrap_or(3600) - 300; // 5 min buffer
183
184        let access_token = token_resp.access_token.clone();
185
186        // Cache it
187        {
188            let mut guard = self.token.lock().unwrap();
189            *guard = Some(CachedToken {
190                token: token_resp.access_token,
191                expires_at,
192            });
193        }
194
195        Ok(access_token)
196    }
197
198    /// List top-level "directories" (prefixes) in the bucket.
199    /// Returns skill names like `["fal-generate", "compliance-screening", ...]`.
200    pub async fn list_skill_names(&self) -> Result<Vec<String>, GcsError> {
201        let url = format!(
202            "https://storage.googleapis.com/storage/v1/b/{}/o?delimiter=/",
203            self.bucket
204        );
205
206        let resp = self.get_with_retry(&url).await?;
207
208        if !resp.status().is_success() {
209            let status = resp.status().as_u16();
210            let body = resp.text().await.unwrap_or_default();
211            return Err(GcsError::Api {
212                status,
213                message: body,
214            });
215        }
216
217        #[derive(Deserialize)]
218        struct ListResponse {
219            #[serde(default)]
220            prefixes: Vec<String>,
221        }
222
223        let list: ListResponse = resp.json().await?;
224        Ok(list
225            .prefixes
226            .into_iter()
227            .map(|p| p.trim_end_matches('/').to_string())
228            .filter(|p| !p.is_empty())
229            .collect())
230    }
231
232    /// List all objects under a prefix (recursive).
233    /// Returns relative paths like `["SKILL.md", "skill.toml", "scripts/generate.sh"]`.
234    pub async fn list_objects(&self, prefix: &str) -> Result<Vec<String>, GcsError> {
235        let full_prefix = format!("{}/", prefix.trim_end_matches('/'));
236        let mut all_objects = Vec::new();
237        let mut page_token: Option<String> = None;
238
239        loop {
240            let mut url = format!(
241                "https://storage.googleapis.com/storage/v1/b/{}/o?prefix={}",
242                self.bucket,
243                urlencoded(&full_prefix)
244            );
245            if let Some(ref pt) = page_token {
246                url.push_str(&format!("&pageToken={}", urlencoded(pt)));
247            }
248
249            let resp = self.get_with_retry(&url).await?;
250
251            if !resp.status().is_success() {
252                let status = resp.status().as_u16();
253                let body = resp.text().await.unwrap_or_default();
254                return Err(GcsError::Api {
255                    status,
256                    message: body,
257                });
258            }
259
260            #[derive(Deserialize)]
261            struct ListResponse {
262                #[serde(default)]
263                items: Vec<ObjectItem>,
264                #[serde(rename = "nextPageToken")]
265                next_page_token: Option<String>,
266            }
267
268            #[derive(Deserialize)]
269            struct ObjectItem {
270                name: String,
271            }
272
273            let list: ListResponse = resp.json().await?;
274
275            for item in list.items {
276                // Strip the prefix to get relative path
277                if let Some(rel) = item.name.strip_prefix(&full_prefix) {
278                    if !rel.is_empty() {
279                        all_objects.push(rel.to_string());
280                    }
281                }
282            }
283
284            match list.next_page_token {
285                Some(pt) => page_token = Some(pt),
286                None => break,
287            }
288        }
289
290        Ok(all_objects)
291    }
292
293    /// Read a single object as bytes.
294    pub async fn get_object(&self, path: &str) -> Result<Vec<u8>, GcsError> {
295        let url = format!(
296            "https://storage.googleapis.com/storage/v1/b/{}/o/{}?alt=media",
297            self.bucket,
298            urlencoded(path)
299        );
300
301        let resp = self.get_with_retry(&url).await?;
302
303        if !resp.status().is_success() {
304            let status = resp.status().as_u16();
305            let body = resp.text().await.unwrap_or_default();
306            return Err(GcsError::Api {
307                status,
308                message: body,
309            });
310        }
311
312        Ok(resp.bytes().await?.to_vec())
313    }
314
315    /// Read a single object as UTF-8 text.
316    pub async fn get_object_text(&self, path: &str) -> Result<String, GcsError> {
317        let bytes = self.get_object(path).await?;
318        String::from_utf8(bytes).map_err(|e| GcsError::Utf8(e.to_string()))
319    }
320
321    /// Upload bytes to `<bucket>/<object_name>` using the GCS JSON simple upload API.
322    /// Returns the public-style URL `https://storage.googleapis.com/<bucket>/<object_name>`.
323    /// The object is *not* made public — the URL only resolves if the bucket grants
324    /// public read access, which is the proxy-operator's responsibility to configure.
325    pub async fn upload_object(
326        &self,
327        object_name: &str,
328        bytes: Vec<u8>,
329        content_type: &str,
330    ) -> Result<String, GcsError> {
331        let url = format!(
332            "https://storage.googleapis.com/upload/storage/v1/b/{}/o?uploadType=media&name={}",
333            self.bucket,
334            urlencoded(object_name)
335        );
336
337        // Retry 429/5xx up to 3 times with exponential backoff — matches the
338        // pattern used by `get_with_retry`. Uploads are idempotent with the
339        // JSON simple-upload API (each call fully replaces the object at
340        // `name=`), so retrying on transient failures is safe.
341        //
342        // `bytes::Bytes` is Arc-backed, so cloning across retries is O(1) —
343        // the alternative is a 1 GB memcpy per attempt on large uploads.
344        let body = bytes::Bytes::from(bytes);
345        let mut last_err: Option<GcsError> = None;
346        for attempt in 0..3 {
347            let token = self.access_token().await?;
348            match self
349                .http
350                .post(&url)
351                .bearer_auth(&token)
352                .header(reqwest::header::CONTENT_TYPE, content_type)
353                .body(body.clone())
354                .send()
355                .await
356            {
357                Ok(resp) => {
358                    let status = resp.status().as_u16();
359                    if status == 429 || status >= 500 {
360                        let body = resp.text().await.unwrap_or_default();
361                        last_err = Some(GcsError::Api {
362                            status,
363                            message: body,
364                        });
365                        let delay = std::time::Duration::from_millis(500 * (1 << attempt));
366                        tokio::time::sleep(delay).await;
367                        continue;
368                    }
369                    if !resp.status().is_success() {
370                        let body = resp.text().await.unwrap_or_default();
371                        return Err(GcsError::Api {
372                            status,
373                            message: body,
374                        });
375                    }
376                    // Success — return the canonical public URL. Path-style so
377                    // object names with `/` segments round-trip cleanly.
378                    return Ok(format!(
379                        "https://storage.googleapis.com/{}/{}",
380                        self.bucket,
381                        object_name
382                            .split('/')
383                            .map(percent_encode_segment)
384                            .collect::<Vec<_>>()
385                            .join("/")
386                    ));
387                }
388                Err(e) => {
389                    last_err = Some(GcsError::Http(e));
390                    let delay = std::time::Duration::from_millis(500 * (1 << attempt));
391                    tokio::time::sleep(delay).await;
392                }
393            }
394        }
395        Err(last_err.expect("loop body sets last_err on every failure path"))
396    }
397
398    /// Retry-aware wrapper for GET requests. Retries on 429/5xx up to 3 times with backoff.
399    async fn get_with_retry(&self, url: &str) -> Result<reqwest::Response, GcsError> {
400        let mut last_err = None;
401        for attempt in 0..3 {
402            let token = self.access_token().await?;
403            match self.http.get(url).bearer_auth(&token).send().await {
404                Ok(resp) => {
405                    let status = resp.status().as_u16();
406                    if status == 429 || status >= 500 {
407                        let body = resp.text().await.unwrap_or_default();
408                        last_err = Some(GcsError::Api {
409                            status,
410                            message: body,
411                        });
412                        let delay = std::time::Duration::from_millis(500 * (1 << attempt));
413                        tokio::time::sleep(delay).await;
414                        continue;
415                    }
416                    return Ok(resp);
417                }
418                Err(e) => {
419                    last_err = Some(GcsError::Http(e));
420                    let delay = std::time::Duration::from_millis(500 * (1 << attempt));
421                    tokio::time::sleep(delay).await;
422                }
423            }
424        }
425        Err(last_err.unwrap())
426    }
427}
428
/// Percent-encode `s` for use as a single URL path segment or query value
/// (GCS object names, prefixes, page tokens).
///
/// Every byte except the RFC 3986 "unreserved" set (`A-Z a-z 0-9 - . _ ~`)
/// is written as `%XX` with uppercase hex digits, working on the UTF-8 bytes
/// of the input. This covers the characters previously handled one by one
/// (`%`, space, `/`, `?`, `#`, `&`, `=`) and additionally `+`, which would
/// otherwise be read as a space in a query string, plus non-ASCII text and
/// the remaining reserved characters.
fn urlencoded(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for &byte in s.as_bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                out.push(char::from(byte));
            }
            _ => {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    out
}
439
440// `percent_encode_segment` used to live here; it was an exact duplicate of
441// `core::http::percent_encode_path_segment`. Import that instead.
442use crate::core::http::percent_encode_path_segment as percent_encode_segment;
443
444// ---------------------------------------------------------------------------
445// GCS skill source — loads all skills from a bucket into memory
446// ---------------------------------------------------------------------------
447
/// Skills loaded from a GCS bucket, with all files cached in memory.
///
/// Built by `GcsSkillSource::load`; skills that fail to load are logged and
/// left out rather than failing the whole load.
pub struct GcsSkillSource {
    /// Parsed skill metadata, one entry per successfully loaded skill.
    pub skills: Vec<SkillMeta>,
    /// Raw bytes of every fetched file, keyed by (skill_name, relative_path),
    /// where the path is relative to the skill's top-level prefix.
    pub files: HashMap<(String, String), Vec<u8>>,
}
455
456impl GcsSkillSource {
457    /// Load all skills from a GCS bucket concurrently.
458    ///
459    /// Enumerates top-level "directories" as skill names, then fetches
460    /// all files in each skill directory with bounded concurrency.
461    pub async fn load(client: &GcsClient) -> Result<Self, GcsError> {
462        use futures::stream::{self, StreamExt};
463
464        let skill_names = client.list_skill_names().await?;
465        tracing::debug!(count = skill_names.len(), "discovered skills in GCS bucket");
466
467        // Load all skills concurrently (up to 50 at a time)
468        let results: Vec<_> = stream::iter(skill_names)
469            .map(|name| async move { Self::load_one_skill(client, &name).await })
470            .buffer_unordered(50)
471            .collect()
472            .await;
473
474        let mut skills = Vec::new();
475        let mut files: HashMap<(String, String), Vec<u8>> = HashMap::new();
476
477        for result in results {
478            match result {
479                Ok((meta, skill_files)) => {
480                    skills.push(meta);
481                    files.extend(skill_files);
482                }
483                Err((name, e)) => {
484                    tracing::warn!(skill = %name, error = %e, "failed to load GCS skill");
485                }
486            }
487        }
488
489        Ok(GcsSkillSource { skills, files })
490    }
491
492    /// Load a single skill: list its files, fetch them concurrently, parse metadata.
493    async fn load_one_skill(
494        client: &GcsClient,
495        name: &str,
496    ) -> Result<(SkillMeta, Vec<((String, String), Vec<u8>)>), (String, String)> {
497        use futures::stream::{self, StreamExt};
498
499        let objects = client
500            .list_objects(name)
501            .await
502            .map_err(|e| (name.to_string(), e.to_string()))?;
503
504        // Fetch all files in this skill concurrently
505        let file_results: Vec<_> = stream::iter(objects)
506            .map(|rel_path| {
507                let full_path = format!("{}/{}", name, rel_path);
508                let name = name.to_string();
509                async move {
510                    match client.get_object(&full_path).await {
511                        Ok(data) => Some(((name, rel_path), data)),
512                        Err(e) => {
513                            tracing::warn!(path = %full_path, error = %e, "failed to fetch file");
514                            None
515                        }
516                    }
517                }
518            })
519            .buffer_unordered(20)
520            .collect()
521            .await;
522
523        let file_entries: Vec<((String, String), Vec<u8>)> =
524            file_results.into_iter().flatten().collect();
525
526        // Parse metadata
527        let skill_md = file_entries
528            .iter()
529            .find(|((_, p), _)| p == "SKILL.md")
530            .and_then(|(_, data)| std::str::from_utf8(data).ok())
531            .unwrap_or("");
532
533        let skill_toml = file_entries
534            .iter()
535            .find(|((_, p), _)| p == "skill.toml")
536            .and_then(|(_, data)| std::str::from_utf8(data).ok());
537
538        let meta = skill::parse_skill_metadata(name, skill_md, skill_toml)
539            .map_err(|e| (name.to_string(), e.to_string()))?;
540
541        Ok((meta, file_entries))
542    }
543
    /// Number of skills loaded, i.e. the length of `skills`.
    pub fn skill_count(&self) -> usize {
        self.skills.len()
    }
548}