1use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::Mutex;
10use thiserror::Error;
11
12use crate::core::skill::{self, SkillMeta};
13
14#[derive(Error, Debug)]
15pub enum GcsError {
16 #[error("HTTP error: {0}")]
17 Http(#[from] reqwest::Error),
18 #[error("Auth error: {0}")]
19 Auth(String),
20 #[error("JSON parse error: {0}")]
21 Json(#[from] serde_json::Error),
22 #[error("JWT signing error: {0}")]
23 Jwt(#[from] jsonwebtoken::errors::Error),
24 #[error("GCS API error ({status}): {message}")]
25 Api { status: u16, message: String },
26 #[error("Invalid UTF-8 in GCS object: {0}")]
27 Utf8(String),
28 #[error("Invalid service account JSON: {0}")]
29 InvalidCredentials(String),
30}
31
32#[derive(Deserialize)]
37struct ServiceAccount {
38 client_email: String,
39 private_key: String,
40 #[serde(default = "default_token_uri")]
41 token_uri: String,
42}
43
/// Serde default for [`ServiceAccount::token_uri`]: Google's OAuth 2.0 token endpoint.
fn default_token_uri() -> String {
    String::from("https://oauth2.googleapis.com/token")
}
47
/// An access token together with the moment it should stop being reused.
struct CachedToken {
    /// Bearer token value sent in the `Authorization` header.
    token: String,
    /// Unix time in seconds; the token is reused only while the current time is below this.
    expires_at: u64,
}
52
53pub struct GcsClient {
59 bucket: String,
60 http: reqwest::Client,
61 service_account: ServiceAccount,
62 token: Mutex<Option<CachedToken>>,
63}
64
65impl GcsClient {
66 pub fn new(bucket: String, service_account_json: &str) -> Result<Self, GcsError> {
68 let sa: ServiceAccount = serde_json::from_str(service_account_json)
69 .map_err(|e| GcsError::InvalidCredentials(e.to_string()))?;
70
71 if sa.client_email.is_empty() || sa.private_key.is_empty() {
72 return Err(GcsError::InvalidCredentials(
73 "client_email and private_key are required".into(),
74 ));
75 }
76
77 let http = reqwest::Client::builder()
78 .timeout(std::time::Duration::from_secs(30))
79 .build()
80 .map_err(GcsError::Http)?;
81
82 Ok(Self {
83 bucket,
84 http,
85 service_account: sa,
86 token: Mutex::new(None),
87 })
88 }
89
90 async fn access_token(&self) -> Result<String, GcsError> {
92 {
94 let guard = self.token.lock().unwrap();
95 if let Some(ref cached) = *guard {
96 let now = std::time::SystemTime::now()
97 .duration_since(std::time::UNIX_EPOCH)
98 .unwrap()
99 .as_secs();
100 if now < cached.expires_at {
101 return Ok(cached.token.clone());
102 }
103 }
104 }
105
106 let now = std::time::SystemTime::now()
108 .duration_since(std::time::UNIX_EPOCH)
109 .unwrap()
110 .as_secs();
111
112 let claims = serde_json::json!({
113 "iss": self.service_account.client_email,
114 "scope": "https://www.googleapis.com/auth/devstorage.read_only",
115 "aud": self.service_account.token_uri,
116 "iat": now,
117 "exp": now + 3600,
118 });
119
120 let key =
121 jsonwebtoken::EncodingKey::from_rsa_pem(self.service_account.private_key.as_bytes())
122 .map_err(|e| GcsError::Auth(format!("invalid RSA key: {e}")))?;
123
124 let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
125 let assertion = jsonwebtoken::encode(&header, &claims, &key)?;
126
127 let resp = self
129 .http
130 .post(&self.service_account.token_uri)
131 .form(&[
132 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
133 ("assertion", &assertion),
134 ])
135 .send()
136 .await?;
137
138 if !resp.status().is_success() {
139 let status = resp.status().as_u16();
140 let body = resp.text().await.unwrap_or_default();
141 return Err(GcsError::Api {
142 status,
143 message: body,
144 });
145 }
146
147 #[derive(Deserialize)]
148 struct TokenResponse {
149 access_token: String,
150 expires_in: Option<u64>,
151 }
152
153 let token_resp: TokenResponse = resp.json().await?;
154 let expires_at = now + token_resp.expires_in.unwrap_or(3600) - 300; let access_token = token_resp.access_token.clone();
157
158 {
160 let mut guard = self.token.lock().unwrap();
161 *guard = Some(CachedToken {
162 token: token_resp.access_token,
163 expires_at,
164 });
165 }
166
167 Ok(access_token)
168 }
169
170 pub async fn list_skill_names(&self) -> Result<Vec<String>, GcsError> {
173 let url = format!(
174 "https://storage.googleapis.com/storage/v1/b/{}/o?delimiter=/",
175 self.bucket
176 );
177
178 let resp = self.get_with_retry(&url).await?;
179
180 if !resp.status().is_success() {
181 let status = resp.status().as_u16();
182 let body = resp.text().await.unwrap_or_default();
183 return Err(GcsError::Api {
184 status,
185 message: body,
186 });
187 }
188
189 #[derive(Deserialize)]
190 struct ListResponse {
191 #[serde(default)]
192 prefixes: Vec<String>,
193 }
194
195 let list: ListResponse = resp.json().await?;
196 Ok(list
197 .prefixes
198 .into_iter()
199 .map(|p| p.trim_end_matches('/').to_string())
200 .filter(|p| !p.is_empty())
201 .collect())
202 }
203
204 pub async fn list_objects(&self, prefix: &str) -> Result<Vec<String>, GcsError> {
207 let full_prefix = format!("{}/", prefix.trim_end_matches('/'));
208 let mut all_objects = Vec::new();
209 let mut page_token: Option<String> = None;
210
211 loop {
212 let mut url = format!(
213 "https://storage.googleapis.com/storage/v1/b/{}/o?prefix={}",
214 self.bucket,
215 urlencoded(&full_prefix)
216 );
217 if let Some(ref pt) = page_token {
218 url.push_str(&format!("&pageToken={}", urlencoded(pt)));
219 }
220
221 let resp = self.get_with_retry(&url).await?;
222
223 if !resp.status().is_success() {
224 let status = resp.status().as_u16();
225 let body = resp.text().await.unwrap_or_default();
226 return Err(GcsError::Api {
227 status,
228 message: body,
229 });
230 }
231
232 #[derive(Deserialize)]
233 struct ListResponse {
234 #[serde(default)]
235 items: Vec<ObjectItem>,
236 #[serde(rename = "nextPageToken")]
237 next_page_token: Option<String>,
238 }
239
240 #[derive(Deserialize)]
241 struct ObjectItem {
242 name: String,
243 }
244
245 let list: ListResponse = resp.json().await?;
246
247 for item in list.items {
248 if let Some(rel) = item.name.strip_prefix(&full_prefix) {
250 if !rel.is_empty() {
251 all_objects.push(rel.to_string());
252 }
253 }
254 }
255
256 match list.next_page_token {
257 Some(pt) => page_token = Some(pt),
258 None => break,
259 }
260 }
261
262 Ok(all_objects)
263 }
264
265 pub async fn get_object(&self, path: &str) -> Result<Vec<u8>, GcsError> {
267 let url = format!(
268 "https://storage.googleapis.com/storage/v1/b/{}/o/{}?alt=media",
269 self.bucket,
270 urlencoded(path)
271 );
272
273 let resp = self.get_with_retry(&url).await?;
274
275 if !resp.status().is_success() {
276 let status = resp.status().as_u16();
277 let body = resp.text().await.unwrap_or_default();
278 return Err(GcsError::Api {
279 status,
280 message: body,
281 });
282 }
283
284 Ok(resp.bytes().await?.to_vec())
285 }
286
287 pub async fn get_object_text(&self, path: &str) -> Result<String, GcsError> {
289 let bytes = self.get_object(path).await?;
290 String::from_utf8(bytes).map_err(|e| GcsError::Utf8(e.to_string()))
291 }
292
293 async fn get_with_retry(&self, url: &str) -> Result<reqwest::Response, GcsError> {
295 let mut last_err = None;
296 for attempt in 0..3 {
297 let token = self.access_token().await?;
298 match self.http.get(url).bearer_auth(&token).send().await {
299 Ok(resp) => {
300 let status = resp.status().as_u16();
301 if status == 429 || status >= 500 {
302 let body = resp.text().await.unwrap_or_default();
303 last_err = Some(GcsError::Api {
304 status,
305 message: body,
306 });
307 let delay = std::time::Duration::from_millis(500 * (1 << attempt));
308 tokio::time::sleep(delay).await;
309 continue;
310 }
311 return Ok(resp);
312 }
313 Err(e) => {
314 last_err = Some(GcsError::Http(e));
315 let delay = std::time::Duration::from_millis(500 * (1 << attempt));
316 tokio::time::sleep(delay).await;
317 }
318 }
319 }
320 Err(last_err.unwrap())
321 }
322}
323
/// Percent-encodes `s` for use as a single URL path segment or query-parameter value.
///
/// Every byte outside the RFC 3986 "unreserved" set (`A-Z a-z 0-9 - . _ ~`) is written
/// as `%XX` with uppercase hex. Multi-byte UTF-8 characters are encoded byte by byte.
/// This covers `/`, `?`, `#`, `&`, `=`, `%` and space, as before. It also covers
/// characters such as `+`, `;` or `@` that would otherwise be misinterpreted in object
/// names or in the `prefix`/`pageToken` query values.
fn urlencoded(s: &str) -> String {
    use std::fmt::Write;

    let mut out = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(out, "%{byte:02X}");
        }
    }
    out
}
334
335pub struct GcsSkillSource {
341 pub skills: Vec<SkillMeta>,
343 pub files: HashMap<(String, String), Vec<u8>>,
345}
346
347impl GcsSkillSource {
348 pub async fn load(client: &GcsClient) -> Result<Self, GcsError> {
353 use futures::stream::{self, StreamExt};
354
355 let skill_names = client.list_skill_names().await?;
356 tracing::debug!(count = skill_names.len(), "discovered skills in GCS bucket");
357
358 let results: Vec<_> = stream::iter(skill_names)
360 .map(|name| async move { Self::load_one_skill(client, &name).await })
361 .buffer_unordered(50)
362 .collect()
363 .await;
364
365 let mut skills = Vec::new();
366 let mut files: HashMap<(String, String), Vec<u8>> = HashMap::new();
367
368 for result in results {
369 match result {
370 Ok((meta, skill_files)) => {
371 skills.push(meta);
372 files.extend(skill_files);
373 }
374 Err((name, e)) => {
375 tracing::warn!(skill = %name, error = %e, "failed to load GCS skill");
376 }
377 }
378 }
379
380 Ok(GcsSkillSource { skills, files })
381 }
382
383 async fn load_one_skill(
385 client: &GcsClient,
386 name: &str,
387 ) -> Result<(SkillMeta, Vec<((String, String), Vec<u8>)>), (String, String)> {
388 use futures::stream::{self, StreamExt};
389
390 let objects = client
391 .list_objects(name)
392 .await
393 .map_err(|e| (name.to_string(), e.to_string()))?;
394
395 let file_results: Vec<_> = stream::iter(objects)
397 .map(|rel_path| {
398 let full_path = format!("{}/{}", name, rel_path);
399 let name = name.to_string();
400 async move {
401 match client.get_object(&full_path).await {
402 Ok(data) => Some(((name, rel_path), data)),
403 Err(e) => {
404 tracing::warn!(path = %full_path, error = %e, "failed to fetch file");
405 None
406 }
407 }
408 }
409 })
410 .buffer_unordered(20)
411 .collect()
412 .await;
413
414 let file_entries: Vec<((String, String), Vec<u8>)> =
415 file_results.into_iter().flatten().collect();
416
417 let skill_md = file_entries
419 .iter()
420 .find(|((_, p), _)| p == "SKILL.md")
421 .and_then(|(_, data)| std::str::from_utf8(data).ok())
422 .unwrap_or("");
423
424 let skill_toml = file_entries
425 .iter()
426 .find(|((_, p), _)| p == "skill.toml")
427 .and_then(|(_, data)| std::str::from_utf8(data).ok());
428
429 let meta = skill::parse_skill_metadata(name, skill_md, skill_toml)
430 .map_err(|e| (name.to_string(), e.to_string()))?;
431
432 Ok((meta, file_entries))
433 }
434
435 pub fn skill_count(&self) -> usize {
437 self.skills.len()
438 }
439}