1use crate::crd::CRD;
4use anyhow::Result;
5use futures::StreamExt;
6use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::Mutex;
10
11pub struct CRDFetcher {
12 client: reqwest::Client,
13 multi_progress: Arc<MultiProgress>,
14}
15
16impl CRDFetcher {
17 pub fn new() -> Result<Self> {
18 Ok(Self {
19 client: reqwest::Client::builder()
20 .timeout(Duration::from_secs(30))
21 .user_agent("amalgam")
22 .build()?,
23 multi_progress: Arc::new(MultiProgress::new()),
24 })
25 }
26
27 pub async fn fetch_from_url(&self, url: &str) -> Result<Vec<CRD>> {
33 let is_tty = atty::is(atty::Stream::Stdout);
34
35 let main_spinner = if is_tty {
36 let pb = self.multi_progress.add(ProgressBar::new_spinner());
37 pb.set_style(
38 ProgressStyle::default_spinner()
39 .template("{spinner:.cyan} {msg}")?
40 .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]),
41 );
42 pb.enable_steady_tick(Duration::from_millis(100));
43 pb.set_message("Initializing CRD fetcher...");
44 Some(pb)
45 } else {
46 None
47 };
48
49 let result = if url.contains("github.com") {
50 self.fetch_from_github(url, is_tty).await
51 } else if url.ends_with(".yaml") || url.ends_with(".yml") {
52 if let Some(ref pb) = main_spinner {
54 pb.set_message("Downloading YAML file...".to_string());
55 } else {
56 println!("Downloading YAML file from {}", url);
57 }
58 let content = self.client.get(url).send().await?.text().await?;
59 let crd: CRD = serde_yaml::from_str(&content)?;
60 Ok(vec![crd])
61 } else {
62 self.fetch_directory(url).await
64 };
65
66 if let Some(pb) = main_spinner {
67 if let Ok(ref crds) = result {
68 pb.finish_with_message(format!("✓ Successfully fetched {} CRDs", crds.len()));
69 } else {
70 pb.finish_with_message("✗ Failed to fetch CRDs");
71 }
72 } else if let Ok(ref crds) = result {
73 println!("Successfully fetched {} CRDs", crds.len());
74 }
75
76 result
77 }
78
79 async fn fetch_from_github(&self, url: &str, is_tty: bool) -> Result<Vec<CRD>> {
81 let parts: Vec<&str> = url.split('/').collect();
83 if parts.len() < 5 {
84 return Err(anyhow::anyhow!("Invalid GitHub URL"));
85 }
86
87 let owner = parts[3];
88 let repo = parts[4];
89
90 let (path, branch) = if let Some(tree_idx) = parts.iter().position(|&p| p == "tree") {
92 if parts.len() > tree_idx + 2 {
93 let branch = parts[tree_idx + 1];
94 let path = parts[tree_idx + 2..].join("/");
95 (path, branch)
96 } else if parts.len() > tree_idx + 1 {
97 let branch = parts[tree_idx + 1];
98 (String::new(), branch)
99 } else {
100 (String::new(), "main")
101 }
102 } else if let Some(blob_idx) = parts.iter().position(|&p| p == "blob") {
103 if parts.len() > blob_idx + 2 {
105 let branch = parts[blob_idx + 1];
106 let file_path = parts[blob_idx + 2..].join("/");
107 let raw_url = format!(
108 "https://raw.githubusercontent.com/{}/{}/{}/{}",
109 owner, repo, branch, file_path
110 );
111
112 let pb = if is_tty {
113 let pb = self.multi_progress.add(ProgressBar::new_spinner());
114 pb.set_style(
115 ProgressStyle::default_spinner().template("{spinner:.cyan} {msg}")?,
116 );
117 pb.enable_steady_tick(Duration::from_millis(100));
118 pb.set_message(format!("Downloading {}", file_path));
119 Some(pb)
120 } else {
121 println!("Downloading {}", file_path);
122 None
123 };
124
125 let content = self.client.get(&raw_url).send().await?.text().await?;
126 let crd: CRD = serde_yaml::from_str(&content)?;
127
128 if let Some(pb) = pb {
129 pb.finish_with_message(format!("✓ Downloaded {}", file_path));
130 }
131
132 return Ok(vec![crd]);
133 }
134 (String::new(), "main")
135 } else {
136 (String::new(), "main")
137 };
138
139 let api_url = format!(
141 "https://api.github.com/repos/{}/{}/contents/{}?ref={}",
142 owner, repo, path, branch
143 );
144
145 let listing_pb = if is_tty {
146 let pb = self.multi_progress.add(ProgressBar::new_spinner());
147 pb.set_style(ProgressStyle::default_spinner().template("{spinner:.cyan} {msg}")?);
148 pb.enable_steady_tick(Duration::from_millis(100));
149 pb.set_message(format!("Listing files from {}/{}/{}", owner, repo, path));
150 Some(pb)
151 } else {
152 println!("Listing files from {}/{}/{}", owner, repo, path);
153 None
154 };
155
156 let response = self
157 .client
158 .get(&api_url)
159 .header("Accept", "application/vnd.github.v3+json")
160 .send()
161 .await?;
162
163 if !response.status().is_success() {
164 let status = response.status();
165 let text = response.text().await?;
166 return Err(anyhow::anyhow!("GitHub API error ({}): {}", status, text));
167 }
168
169 let files: Vec<GitHubContent> = response.json().await?;
170
171 let yaml_files: Vec<_> = files
173 .iter()
174 .filter(|item| item.name.ends_with(".yaml") || item.name.ends_with(".yml"))
175 .collect();
176
177 if let Some(pb) = listing_pb {
178 pb.finish_with_message(format!("✓ Found {} YAML files", yaml_files.len()));
179 } else {
180 println!("Found {} YAML files", yaml_files.len());
181 }
182
183 if yaml_files.is_empty() {
184 return Ok(Vec::new());
185 }
186
187 let main_progress = if is_tty {
189 let pb = self
190 .multi_progress
191 .add(ProgressBar::new(yaml_files.len() as u64));
192 pb.set_style(
193 ProgressStyle::default_bar()
194 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} {msg}")?
195 .progress_chars("##-")
196 );
197 pb.set_message("Overall progress");
198 Some(Arc::new(pb))
199 } else {
200 None
201 };
202
203 let max_concurrent = 5;
205 let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent));
206 let client = self.client.clone();
207 let multi_progress = self.multi_progress.clone();
208 let active_downloads = Arc::new(Mutex::new(Vec::new()));
209
210 let total_files = yaml_files.len();
211 let download_tasks = yaml_files.iter().enumerate().map(|(idx, item)| {
212 let client = client.clone();
213 let semaphore = semaphore.clone();
214 let name = item.name.clone();
215 let download_url = item.download_url.clone();
216 let main_progress = main_progress.clone();
217 let multi_progress = multi_progress.clone();
218 let active_downloads = active_downloads.clone();
219
220 async move {
221 let _permit = semaphore.acquire().await.unwrap();
222
223 let individual_pb = if is_tty {
225 let pb = multi_progress.add(ProgressBar::new_spinner());
226 pb.set_style(
227 ProgressStyle::default_spinner()
228 .template(&format!(" {{spinner:.yellow}} [{}] {{msg}}", idx + 1))
229 .unwrap()
230 .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]),
231 );
232 pb.enable_steady_tick(Duration::from_millis(80));
233 pb.set_message(format!("Downloading {}", name));
234
235 active_downloads.lock().await.push(pb.clone());
237
238 Some(pb)
239 } else {
240 println!("[{}/{}] Downloading {}", idx + 1, total_files, name);
241 None
242 };
243
244 let result = if let Some(url) = download_url {
245 match fetch_single_crd(&client, &url).await {
246 Ok(crd) => {
247 if let Some(ref pb) = individual_pb {
248 pb.finish_with_message(format!("✓ {}", name));
249 }
250 Some(crd)
251 }
252 Err(e) => {
253 if let Some(ref pb) = individual_pb {
254 pb.finish_with_message(format!("✗ {} ({})", name, e));
255 } else {
256 eprintln!("Failed to parse {}: {}", name, e);
257 }
258 None
259 }
260 }
261 } else {
262 if let Some(ref pb) = individual_pb {
263 pb.finish_with_message(format!("✗ {} (no download URL)", name));
264 }
265 None
266 };
267
268 if let Some(ref main_pb) = main_progress {
270 main_pb.inc(1);
271 let completed = main_pb.position();
272 let total = main_pb.length().unwrap_or(0);
273 main_pb.set_message(format!("Completed {}/{} files", completed, total));
274 }
275
276 if let Some(ref pb) = individual_pb {
278 let mut active = active_downloads.lock().await;
279 active.retain(|p| !Arc::ptr_eq(&Arc::new(p.clone()), &Arc::new(pb.clone())));
280 }
281
282 result
283 }
284 });
285
286 let mut stream = futures::stream::iter(download_tasks).buffer_unordered(max_concurrent);
287
288 let mut crds = Vec::new();
289 while let Some(result) = stream.next().await {
290 if let Some(crd) = result {
291 crds.push(crd);
292 }
293 }
294
295 if let Some(ref main_pb) = main_progress {
296 main_pb.finish_with_message(format!(
297 "✓ Successfully downloaded {} valid CRDs",
298 crds.len()
299 ));
300 } else {
301 println!("Downloaded {} valid CRDs", crds.len());
302 }
303
304 Ok(crds)
305 }
306
307 async fn fetch_directory(&self, _url: &str) -> Result<Vec<CRD>> {
308 Err(anyhow::anyhow!(
311 "Directory listing not supported for non-GitHub URLs"
312 ))
313 }
314
315 pub fn finish(&self) {
317 self.multi_progress.clear().ok();
318 }
319}
320
321async fn fetch_single_crd(client: &reqwest::Client, url: &str) -> Result<CRD> {
322 let content = client.get(url).send().await?.text().await?;
323
324 if let Ok(crd) = serde_yaml::from_str::<CRD>(&content) {
326 return Ok(crd);
327 }
328
329 let value: serde_yaml::Value = serde_yaml::from_str(&content)?;
331 if value.get("kind")
332 == Some(&serde_yaml::Value::String(
333 "CustomResourceDefinition".to_string(),
334 ))
335 {
336 let crd: CRD = serde_yaml::from_value(value)?;
337 return Ok(crd);
338 }
339
340 Err(anyhow::anyhow!("Not a valid CRD"))
341}
342
343#[derive(Debug, serde::Deserialize)]
344struct GitHubContent {
345 name: String,
346 #[allow(dead_code)]
347 path: String,
348 #[serde(rename = "type")]
349 #[allow(dead_code)]
350 content_type: String,
351 download_url: Option<String>,
352}
353
354impl Default for CRDFetcher {
355 fn default() -> Self {
356 Self::new().unwrap()
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    /// A minimal Crossplane `Composition` CRD used as the download fixture.
    fn sample_crd() -> serde_json::Value {
        let spec_properties = json!({
            "compositeTypeRef": {
                "type": "object",
                "properties": {
                    "apiVersion": {"type": "string"},
                    "kind": {"type": "string"}
                }
            }
        });

        let schema = json!({
            "openAPIV3Schema": {
                "type": "object",
                "properties": {
                    "spec": {
                        "type": "object",
                        "properties": spec_properties
                    }
                }
            }
        });

        json!({
            "apiVersion": "apiextensions.k8s.io/v1",
            "kind": "CustomResourceDefinition",
            "metadata": { "name": "compositions.apiextensions.crossplane.io" },
            "spec": {
                "group": "apiextensions.crossplane.io",
                "names": {
                    "kind": "Composition",
                    "plural": "compositions",
                    "singular": "composition"
                },
                "versions": [{
                    "name": "v1",
                    "served": true,
                    "storage": true,
                    "schema": schema
                }]
            }
        })
    }

    /// Starts a mock server that answers `GET route` with `response`.
    async fn serve(route: &str, response: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path(route))
            .respond_with(response)
            .mount(&server)
            .await;
        server
    }

    /// Runs a fresh fetcher against `route` on `server`.
    async fn fetch(server: &MockServer, route: &str) -> anyhow::Result<Vec<crate::crd::CRD>> {
        let url = format!("{}{}", server.uri(), route);
        CRDFetcher::new().unwrap().fetch_from_url(&url).await
    }

    #[tokio::test]
    async fn test_fetch_single_yaml_file() {
        let body = serde_yaml::to_string(&sample_crd()).unwrap();
        let server = serve(
            "/test.yaml",
            ResponseTemplate::new(200).set_body_string(body),
        )
        .await;

        let crds = fetch(&server, "/test.yaml").await.unwrap();

        assert_eq!(crds.len(), 1);
        let crd = &crds[0];
        assert_eq!(crd.spec.group, "apiextensions.crossplane.io");
        assert_eq!(crd.spec.names.kind, "Composition");
    }

    #[tokio::test]
    async fn test_error_handling_404() {
        let server = serve("/missing.yaml", ResponseTemplate::new(404)).await;

        let outcome = fetch(&server, "/missing.yaml").await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_error_handling_invalid_yaml() {
        let server = serve(
            "/invalid.yaml",
            ResponseTemplate::new(200).set_body_string("not: valid: yaml: content:"),
        )
        .await;

        let outcome = fetch(&server, "/invalid.yaml").await;

        assert!(outcome.is_err());
    }
}