Skip to main content

ati/core/
gcs.rs

1//! GCS (Google Cloud Storage) client for the ATI skill registry.
2//!
3//! Uses the GCS JSON API directly via `reqwest` — no additional crate dependencies.
4//! Authentication is via service account JSON (the same `gcp_credentials` key
5//! already stored in the ATI keyring).
6
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::Mutex;
10use thiserror::Error;
11
12use crate::core::skill::{self, SkillMeta};
13
14#[derive(Error, Debug)]
15pub enum GcsError {
16    #[error("HTTP error: {0}")]
17    Http(#[from] reqwest::Error),
18    #[error("Auth error: {0}")]
19    Auth(String),
20    #[error("JSON parse error: {0}")]
21    Json(#[from] serde_json::Error),
22    #[error("JWT signing error: {0}")]
23    Jwt(#[from] jsonwebtoken::errors::Error),
24    #[error("GCS API error ({status}): {message}")]
25    Api { status: u16, message: String },
26    #[error("Invalid UTF-8 in GCS object: {0}")]
27    Utf8(String),
28    #[error("Invalid service account JSON: {0}")]
29    InvalidCredentials(String),
30}
31
32// ---------------------------------------------------------------------------
33// Service account credentials
34// ---------------------------------------------------------------------------
35
36#[derive(Deserialize)]
37struct ServiceAccount {
38    client_email: String,
39    private_key: String,
40    #[serde(default = "default_token_uri")]
41    token_uri: String,
42}
43
/// Google's OAuth2 token endpoint, used when the key file has no `token_uri`.
fn default_token_uri() -> String {
    String::from("https://oauth2.googleapis.com/token")
}
47
/// An OAuth2 access token together with the moment it should be refreshed.
struct CachedToken {
    /// The bearer token sent in the `Authorization` header.
    token: String,
    /// Seconds since the UNIX epoch; the token is reused only while the
    /// current time is strictly below this value.
    expires_at: u64,
}
52
53// ---------------------------------------------------------------------------
54// GCS client
55// ---------------------------------------------------------------------------
56
57/// Minimal GCS client using the JSON API. Authenticates via service account JWT.
58pub struct GcsClient {
59    bucket: String,
60    http: reqwest::Client,
61    service_account: ServiceAccount,
62    token: Mutex<Option<CachedToken>>,
63}
64
65impl GcsClient {
66    /// Create a new GCS client from a bucket name and service account JSON string.
67    pub fn new(bucket: String, service_account_json: &str) -> Result<Self, GcsError> {
68        let sa: ServiceAccount = serde_json::from_str(service_account_json)
69            .map_err(|e| GcsError::InvalidCredentials(e.to_string()))?;
70
71        if sa.client_email.is_empty() || sa.private_key.is_empty() {
72            return Err(GcsError::InvalidCredentials(
73                "client_email and private_key are required".into(),
74            ));
75        }
76
77        let http = reqwest::Client::builder()
78            .timeout(std::time::Duration::from_secs(30))
79            .build()
80            .map_err(GcsError::Http)?;
81
82        Ok(Self {
83            bucket,
84            http,
85            service_account: sa,
86            token: Mutex::new(None),
87        })
88    }
89
90    /// Get a valid access token, refreshing if expired.
91    async fn access_token(&self) -> Result<String, GcsError> {
92        // Check cached token
93        {
94            let guard = self.token.lock().unwrap();
95            if let Some(ref cached) = *guard {
96                let now = std::time::SystemTime::now()
97                    .duration_since(std::time::UNIX_EPOCH)
98                    .unwrap()
99                    .as_secs();
100                if now < cached.expires_at {
101                    return Ok(cached.token.clone());
102                }
103            }
104        }
105
106        // Mint a new token via service account JWT → OAuth2 exchange
107        let now = std::time::SystemTime::now()
108            .duration_since(std::time::UNIX_EPOCH)
109            .unwrap()
110            .as_secs();
111
112        let claims = serde_json::json!({
113            "iss": self.service_account.client_email,
114            "scope": "https://www.googleapis.com/auth/devstorage.read_only",
115            "aud": self.service_account.token_uri,
116            "iat": now,
117            "exp": now + 3600,
118        });
119
120        let key =
121            jsonwebtoken::EncodingKey::from_rsa_pem(self.service_account.private_key.as_bytes())
122                .map_err(|e| GcsError::Auth(format!("invalid RSA key: {e}")))?;
123
124        let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
125        let assertion = jsonwebtoken::encode(&header, &claims, &key)?;
126
127        // Exchange JWT for access token
128        let resp = self
129            .http
130            .post(&self.service_account.token_uri)
131            .form(&[
132                ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
133                ("assertion", &assertion),
134            ])
135            .send()
136            .await?;
137
138        if !resp.status().is_success() {
139            let status = resp.status().as_u16();
140            let body = resp.text().await.unwrap_or_default();
141            return Err(GcsError::Api {
142                status,
143                message: body,
144            });
145        }
146
147        #[derive(Deserialize)]
148        struct TokenResponse {
149            access_token: String,
150            expires_in: Option<u64>,
151        }
152
153        let token_resp: TokenResponse = resp.json().await?;
154        let expires_at = now + token_resp.expires_in.unwrap_or(3600) - 300; // 5 min buffer
155
156        let access_token = token_resp.access_token.clone();
157
158        // Cache it
159        {
160            let mut guard = self.token.lock().unwrap();
161            *guard = Some(CachedToken {
162                token: token_resp.access_token,
163                expires_at,
164            });
165        }
166
167        Ok(access_token)
168    }
169
170    /// List top-level "directories" (prefixes) in the bucket.
171    /// Returns skill names like `["fal-generate", "compliance-screening", ...]`.
172    pub async fn list_skill_names(&self) -> Result<Vec<String>, GcsError> {
173        let url = format!(
174            "https://storage.googleapis.com/storage/v1/b/{}/o?delimiter=/",
175            self.bucket
176        );
177
178        let resp = self.get_with_retry(&url).await?;
179
180        if !resp.status().is_success() {
181            let status = resp.status().as_u16();
182            let body = resp.text().await.unwrap_or_default();
183            return Err(GcsError::Api {
184                status,
185                message: body,
186            });
187        }
188
189        #[derive(Deserialize)]
190        struct ListResponse {
191            #[serde(default)]
192            prefixes: Vec<String>,
193        }
194
195        let list: ListResponse = resp.json().await?;
196        Ok(list
197            .prefixes
198            .into_iter()
199            .map(|p| p.trim_end_matches('/').to_string())
200            .filter(|p| !p.is_empty())
201            .collect())
202    }
203
204    /// List all objects under a prefix (recursive).
205    /// Returns relative paths like `["SKILL.md", "skill.toml", "scripts/generate.sh"]`.
206    pub async fn list_objects(&self, prefix: &str) -> Result<Vec<String>, GcsError> {
207        let full_prefix = format!("{}/", prefix.trim_end_matches('/'));
208        let mut all_objects = Vec::new();
209        let mut page_token: Option<String> = None;
210
211        loop {
212            let mut url = format!(
213                "https://storage.googleapis.com/storage/v1/b/{}/o?prefix={}",
214                self.bucket,
215                urlencoded(&full_prefix)
216            );
217            if let Some(ref pt) = page_token {
218                url.push_str(&format!("&pageToken={}", urlencoded(pt)));
219            }
220
221            let resp = self.get_with_retry(&url).await?;
222
223            if !resp.status().is_success() {
224                let status = resp.status().as_u16();
225                let body = resp.text().await.unwrap_or_default();
226                return Err(GcsError::Api {
227                    status,
228                    message: body,
229                });
230            }
231
232            #[derive(Deserialize)]
233            struct ListResponse {
234                #[serde(default)]
235                items: Vec<ObjectItem>,
236                #[serde(rename = "nextPageToken")]
237                next_page_token: Option<String>,
238            }
239
240            #[derive(Deserialize)]
241            struct ObjectItem {
242                name: String,
243            }
244
245            let list: ListResponse = resp.json().await?;
246
247            for item in list.items {
248                // Strip the prefix to get relative path
249                if let Some(rel) = item.name.strip_prefix(&full_prefix) {
250                    if !rel.is_empty() {
251                        all_objects.push(rel.to_string());
252                    }
253                }
254            }
255
256            match list.next_page_token {
257                Some(pt) => page_token = Some(pt),
258                None => break,
259            }
260        }
261
262        Ok(all_objects)
263    }
264
265    /// Read a single object as bytes.
266    pub async fn get_object(&self, path: &str) -> Result<Vec<u8>, GcsError> {
267        let url = format!(
268            "https://storage.googleapis.com/storage/v1/b/{}/o/{}?alt=media",
269            self.bucket,
270            urlencoded(path)
271        );
272
273        let resp = self.get_with_retry(&url).await?;
274
275        if !resp.status().is_success() {
276            let status = resp.status().as_u16();
277            let body = resp.text().await.unwrap_or_default();
278            return Err(GcsError::Api {
279                status,
280                message: body,
281            });
282        }
283
284        Ok(resp.bytes().await?.to_vec())
285    }
286
287    /// Read a single object as UTF-8 text.
288    pub async fn get_object_text(&self, path: &str) -> Result<String, GcsError> {
289        let bytes = self.get_object(path).await?;
290        String::from_utf8(bytes).map_err(|e| GcsError::Utf8(e.to_string()))
291    }
292
293    /// Retry-aware wrapper for GET requests. Retries on 429/5xx up to 3 times with backoff.
294    async fn get_with_retry(&self, url: &str) -> Result<reqwest::Response, GcsError> {
295        let mut last_err = None;
296        for attempt in 0..3 {
297            let token = self.access_token().await?;
298            match self.http.get(url).bearer_auth(&token).send().await {
299                Ok(resp) => {
300                    let status = resp.status().as_u16();
301                    if status == 429 || status >= 500 {
302                        let body = resp.text().await.unwrap_or_default();
303                        last_err = Some(GcsError::Api {
304                            status,
305                            message: body,
306                        });
307                        let delay = std::time::Duration::from_millis(500 * (1 << attempt));
308                        tokio::time::sleep(delay).await;
309                        continue;
310                    }
311                    return Ok(resp);
312                }
313                Err(e) => {
314                    last_err = Some(GcsError::Http(e));
315                    let delay = std::time::Duration::from_millis(500 * (1 << attempt));
316                    tokio::time::sleep(delay).await;
317                }
318            }
319        }
320        Err(last_err.unwrap())
321    }
322}
323
/// Percent-encode a string for use as a single URL path segment or query value.
///
/// Every byte outside the RFC 3986 "unreserved" set (`A-Z a-z 0-9 - . _ ~`)
/// becomes `%XX` with uppercase hex digits, which is what the GCS JSON API
/// expects for object names in the request path. Compared with the previous
/// hand-picked replacements, this also encodes:
/// * `+`, which a query string would otherwise decode as a space. This breaks
///   base64 `pageToken` values.
/// * non-ASCII bytes (as their UTF-8 encoding).
/// * the remaining reserved characters.
fn urlencoded(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
334
335// ---------------------------------------------------------------------------
336// GCS skill source — loads all skills from a bucket into memory
337// ---------------------------------------------------------------------------
338
339/// Skills loaded from a GCS bucket, with all files cached in memory.
340pub struct GcsSkillSource {
341    /// Parsed skill metadata.
342    pub skills: Vec<SkillMeta>,
343    /// All files keyed by (skill_name, relative_path).
344    pub files: HashMap<(String, String), Vec<u8>>,
345}
346
347impl GcsSkillSource {
348    /// Load all skills from a GCS bucket concurrently.
349    ///
350    /// Enumerates top-level "directories" as skill names, then fetches
351    /// all files in each skill directory with bounded concurrency.
352    pub async fn load(client: &GcsClient) -> Result<Self, GcsError> {
353        use futures::stream::{self, StreamExt};
354
355        let skill_names = client.list_skill_names().await?;
356        tracing::debug!(count = skill_names.len(), "discovered skills in GCS bucket");
357
358        // Load all skills concurrently (up to 50 at a time)
359        let results: Vec<_> = stream::iter(skill_names)
360            .map(|name| async move { Self::load_one_skill(client, &name).await })
361            .buffer_unordered(50)
362            .collect()
363            .await;
364
365        let mut skills = Vec::new();
366        let mut files: HashMap<(String, String), Vec<u8>> = HashMap::new();
367
368        for result in results {
369            match result {
370                Ok((meta, skill_files)) => {
371                    skills.push(meta);
372                    files.extend(skill_files);
373                }
374                Err((name, e)) => {
375                    tracing::warn!(skill = %name, error = %e, "failed to load GCS skill");
376                }
377            }
378        }
379
380        Ok(GcsSkillSource { skills, files })
381    }
382
383    /// Load a single skill: list its files, fetch them concurrently, parse metadata.
384    async fn load_one_skill(
385        client: &GcsClient,
386        name: &str,
387    ) -> Result<(SkillMeta, Vec<((String, String), Vec<u8>)>), (String, String)> {
388        use futures::stream::{self, StreamExt};
389
390        let objects = client
391            .list_objects(name)
392            .await
393            .map_err(|e| (name.to_string(), e.to_string()))?;
394
395        // Fetch all files in this skill concurrently
396        let file_results: Vec<_> = stream::iter(objects)
397            .map(|rel_path| {
398                let full_path = format!("{}/{}", name, rel_path);
399                let name = name.to_string();
400                async move {
401                    match client.get_object(&full_path).await {
402                        Ok(data) => Some(((name, rel_path), data)),
403                        Err(e) => {
404                            tracing::warn!(path = %full_path, error = %e, "failed to fetch file");
405                            None
406                        }
407                    }
408                }
409            })
410            .buffer_unordered(20)
411            .collect()
412            .await;
413
414        let file_entries: Vec<((String, String), Vec<u8>)> =
415            file_results.into_iter().flatten().collect();
416
417        // Parse metadata
418        let skill_md = file_entries
419            .iter()
420            .find(|((_, p), _)| p == "SKILL.md")
421            .and_then(|(_, data)| std::str::from_utf8(data).ok())
422            .unwrap_or("");
423
424        let skill_toml = file_entries
425            .iter()
426            .find(|((_, p), _)| p == "skill.toml")
427            .and_then(|(_, data)| std::str::from_utf8(data).ok());
428
429        let meta = skill::parse_skill_metadata(name, skill_md, skill_toml)
430            .map_err(|e| (name.to_string(), e.to_string()))?;
431
432        Ok((meta, file_entries))
433    }
434
435    /// Number of skills loaded.
436    pub fn skill_count(&self) -> usize {
437        self.skills.len()
438    }
439}