amalgam_parser/
fetch.rs

1//! Fetch CRDs from URLs, GitHub repos, etc.
2
3use crate::crd::CRD;
4use anyhow::Result;
5use futures::StreamExt;
6use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::Mutex;
10
11pub struct CRDFetcher {
12    client: reqwest::Client,
13    multi_progress: Arc<MultiProgress>,
14}
15
16impl CRDFetcher {
17    pub fn new() -> Result<Self> {
18        Ok(Self {
19            client: reqwest::Client::builder()
20                .timeout(Duration::from_secs(30))
21                .user_agent("amalgam")
22                .build()?,
23            multi_progress: Arc::new(MultiProgress::new()),
24        })
25    }
26
27    /// Fetch CRDs from a URL
28    /// Supports:
29    /// - Direct YAML files
30    /// - GitHub repository URLs
31    /// - GitHub directory listings
32    pub async fn fetch_from_url(&self, url: &str) -> Result<Vec<CRD>> {
33        let is_tty = atty::is(atty::Stream::Stdout);
34
35        let main_spinner = if is_tty {
36            let pb = self.multi_progress.add(ProgressBar::new_spinner());
37            pb.set_style(
38                ProgressStyle::default_spinner()
39                    .template("{spinner:.cyan} {msg}")?
40                    .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]),
41            );
42            pb.enable_steady_tick(Duration::from_millis(100));
43            pb.set_message("Initializing CRD fetcher...");
44            Some(pb)
45        } else {
46            None
47        };
48
49        let result = if url.contains("github.com") {
50            self.fetch_from_github(url, is_tty).await
51        } else if url.ends_with(".yaml") || url.ends_with(".yml") {
52            // Direct YAML file
53            if let Some(ref pb) = main_spinner {
54                pb.set_message("Downloading YAML file...".to_string());
55            } else {
56                println!("Downloading YAML file from {}", url);
57            }
58            let content = self.client.get(url).send().await?.text().await?;
59            let crd: CRD = serde_yaml::from_str(&content)?;
60            Ok(vec![crd])
61        } else {
62            // Try to fetch as directory listing
63            self.fetch_directory(url).await
64        };
65
66        if let Some(pb) = main_spinner {
67            if let Ok(ref crds) = result {
68                pb.finish_with_message(format!("✓ Successfully fetched {} CRDs", crds.len()));
69            } else {
70                pb.finish_with_message("✗ Failed to fetch CRDs");
71            }
72        } else if let Ok(ref crds) = result {
73            println!("Successfully fetched {} CRDs", crds.len());
74        }
75
76        result
77    }
78
79    /// Fetch CRDs from a GitHub repository or directory
80    async fn fetch_from_github(&self, url: &str, is_tty: bool) -> Result<Vec<CRD>> {
81        // Convert GitHub URL to raw content URL
82        let parts: Vec<&str> = url.split('/').collect();
83        if parts.len() < 5 {
84            return Err(anyhow::anyhow!("Invalid GitHub URL"));
85        }
86
87        let owner = parts[3];
88        let repo = parts[4];
89
90        // Find the path after tree/branch
91        let (path, branch) = if let Some(tree_idx) = parts.iter().position(|&p| p == "tree") {
92            if parts.len() > tree_idx + 2 {
93                let branch = parts[tree_idx + 1];
94                let path = parts[tree_idx + 2..].join("/");
95                (path, branch)
96            } else if parts.len() > tree_idx + 1 {
97                let branch = parts[tree_idx + 1];
98                (String::new(), branch)
99            } else {
100                (String::new(), "main")
101            }
102        } else if let Some(blob_idx) = parts.iter().position(|&p| p == "blob") {
103            // Single file
104            if parts.len() > blob_idx + 2 {
105                let branch = parts[blob_idx + 1];
106                let file_path = parts[blob_idx + 2..].join("/");
107                let raw_url = format!(
108                    "https://raw.githubusercontent.com/{}/{}/{}/{}",
109                    owner, repo, branch, file_path
110                );
111
112                let pb = if is_tty {
113                    let pb = self.multi_progress.add(ProgressBar::new_spinner());
114                    pb.set_style(
115                        ProgressStyle::default_spinner().template("{spinner:.cyan} {msg}")?,
116                    );
117                    pb.enable_steady_tick(Duration::from_millis(100));
118                    pb.set_message(format!("Downloading {}", file_path));
119                    Some(pb)
120                } else {
121                    println!("Downloading {}", file_path);
122                    None
123                };
124
125                let content = self.client.get(&raw_url).send().await?.text().await?;
126                let crd: CRD = serde_yaml::from_str(&content)?;
127
128                if let Some(pb) = pb {
129                    pb.finish_with_message(format!("✓ Downloaded {}", file_path));
130                }
131
132                return Ok(vec![crd]);
133            }
134            (String::new(), "main")
135        } else {
136            (String::new(), "main")
137        };
138
139        // Use GitHub API to list directory contents
140        let api_url = format!(
141            "https://api.github.com/repos/{}/{}/contents/{}?ref={}",
142            owner, repo, path, branch
143        );
144
145        let listing_pb = if is_tty {
146            let pb = self.multi_progress.add(ProgressBar::new_spinner());
147            pb.set_style(ProgressStyle::default_spinner().template("{spinner:.cyan} {msg}")?);
148            pb.enable_steady_tick(Duration::from_millis(100));
149            pb.set_message(format!("Listing files from {}/{}/{}", owner, repo, path));
150            Some(pb)
151        } else {
152            println!("Listing files from {}/{}/{}", owner, repo, path);
153            None
154        };
155
156        let response = self
157            .client
158            .get(&api_url)
159            .header("Accept", "application/vnd.github.v3+json")
160            .send()
161            .await?;
162
163        if !response.status().is_success() {
164            let status = response.status();
165            let text = response.text().await?;
166            return Err(anyhow::anyhow!("GitHub API error ({}): {}", status, text));
167        }
168
169        let files: Vec<GitHubContent> = response.json().await?;
170
171        // Filter for YAML files that look like CRDs
172        let yaml_files: Vec<_> = files
173            .iter()
174            .filter(|item| item.name.ends_with(".yaml") || item.name.ends_with(".yml"))
175            .collect();
176
177        if let Some(pb) = listing_pb {
178            pb.finish_with_message(format!("✓ Found {} YAML files", yaml_files.len()));
179        } else {
180            println!("Found {} YAML files", yaml_files.len());
181        }
182
183        if yaml_files.is_empty() {
184            return Ok(Vec::new());
185        }
186
187        // Create main progress bar for overall download progress
188        let main_progress = if is_tty {
189            let pb = self
190                .multi_progress
191                .add(ProgressBar::new(yaml_files.len() as u64));
192            pb.set_style(
193                ProgressStyle::default_bar()
194                    .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} {msg}")?
195                    .progress_chars("##-")
196            );
197            pb.set_message("Overall progress");
198            Some(Arc::new(pb))
199        } else {
200            None
201        };
202
203        // Download files concurrently with controlled parallelism
204        let max_concurrent = 5;
205        let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent));
206        let client = self.client.clone();
207        let multi_progress = self.multi_progress.clone();
208        let active_downloads = Arc::new(Mutex::new(Vec::new()));
209
210        let total_files = yaml_files.len();
211        let download_tasks = yaml_files.iter().enumerate().map(|(idx, item)| {
212            let client = client.clone();
213            let semaphore = semaphore.clone();
214            let name = item.name.clone();
215            let download_url = item.download_url.clone();
216            let main_progress = main_progress.clone();
217            let multi_progress = multi_progress.clone();
218            let active_downloads = active_downloads.clone();
219
220            async move {
221                let _permit = semaphore.acquire().await.unwrap();
222
223                // Create individual progress bar for this download
224                let individual_pb = if is_tty {
225                    let pb = multi_progress.add(ProgressBar::new_spinner());
226                    pb.set_style(
227                        ProgressStyle::default_spinner()
228                            .template(&format!("  {{spinner:.yellow}} [{}] {{msg}}", idx + 1))
229                            .unwrap()
230                            .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]),
231                    );
232                    pb.enable_steady_tick(Duration::from_millis(80));
233                    pb.set_message(format!("Downloading {}", name));
234
235                    // Track active download
236                    active_downloads.lock().await.push(pb.clone());
237
238                    Some(pb)
239                } else {
240                    println!("[{}/{}] Downloading {}", idx + 1, total_files, name);
241                    None
242                };
243
244                let result = if let Some(url) = download_url {
245                    match fetch_single_crd(&client, &url).await {
246                        Ok(crd) => {
247                            if let Some(ref pb) = individual_pb {
248                                pb.finish_with_message(format!("✓ {}", name));
249                            }
250                            Some(crd)
251                        }
252                        Err(e) => {
253                            if let Some(ref pb) = individual_pb {
254                                pb.finish_with_message(format!("✗ {} ({})", name, e));
255                            } else {
256                                eprintln!("Failed to parse {}: {}", name, e);
257                            }
258                            None
259                        }
260                    }
261                } else {
262                    if let Some(ref pb) = individual_pb {
263                        pb.finish_with_message(format!("✗ {} (no download URL)", name));
264                    }
265                    None
266                };
267
268                // Update main progress
269                if let Some(ref main_pb) = main_progress {
270                    main_pb.inc(1);
271                    let completed = main_pb.position();
272                    let total = main_pb.length().unwrap_or(0);
273                    main_pb.set_message(format!("Completed {}/{} files", completed, total));
274                }
275
276                // Remove from active downloads
277                if let Some(ref pb) = individual_pb {
278                    let mut active = active_downloads.lock().await;
279                    active.retain(|p| !Arc::ptr_eq(&Arc::new(p.clone()), &Arc::new(pb.clone())));
280                }
281
282                result
283            }
284        });
285
286        let mut stream = futures::stream::iter(download_tasks).buffer_unordered(max_concurrent);
287
288        let mut crds = Vec::new();
289        while let Some(result) = stream.next().await {
290            if let Some(crd) = result {
291                crds.push(crd);
292            }
293        }
294
295        if let Some(ref main_pb) = main_progress {
296            main_pb.finish_with_message(format!(
297                "✓ Successfully downloaded {} valid CRDs",
298                crds.len()
299            ));
300        } else {
301            println!("Downloaded {} valid CRDs", crds.len());
302        }
303
304        Ok(crds)
305    }
306
307    async fn fetch_directory(&self, _url: &str) -> Result<Vec<CRD>> {
308        // For now, just try to list files
309        // In a real implementation, would need directory listing support
310        Err(anyhow::anyhow!(
311            "Directory listing not supported for non-GitHub URLs"
312        ))
313    }
314
315    /// Clear all progress bars
316    pub fn finish(&self) {
317        self.multi_progress.clear().ok();
318    }
319}
320
321async fn fetch_single_crd(client: &reqwest::Client, url: &str) -> Result<CRD> {
322    let content = client.get(url).send().await?.text().await?;
323
324    // Most CRDs are single YAML documents, try that first
325    if let Ok(crd) = serde_yaml::from_str::<CRD>(&content) {
326        return Ok(crd);
327    }
328
329    // If that fails, try parsing as a Value first to check kind
330    let value: serde_yaml::Value = serde_yaml::from_str(&content)?;
331    if value.get("kind")
332        == Some(&serde_yaml::Value::String(
333            "CustomResourceDefinition".to_string(),
334        ))
335    {
336        let crd: CRD = serde_yaml::from_value(value)?;
337        return Ok(crd);
338    }
339
340    Err(anyhow::anyhow!("Not a valid CRD"))
341}
342
343#[derive(Debug, serde::Deserialize)]
344struct GitHubContent {
345    name: String,
346    #[allow(dead_code)]
347    path: String,
348    #[serde(rename = "type")]
349    #[allow(dead_code)]
350    content_type: String,
351    download_url: Option<String>,
352}
353
354impl Default for CRDFetcher {
355    fn default() -> Self {
356        Self::new().unwrap()
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    /// A minimal `apiextensions.k8s.io/v1` CRD used as the happy-path fixture.
    fn sample_crd() -> serde_json::Value {
        json!({
            "apiVersion": "apiextensions.k8s.io/v1",
            "kind": "CustomResourceDefinition",
            "metadata": {
                "name": "compositions.apiextensions.crossplane.io"
            },
            "spec": {
                "group": "apiextensions.crossplane.io",
                "names": {
                    "kind": "Composition",
                    "plural": "compositions",
                    "singular": "composition"
                },
                "versions": [{
                    "name": "v1",
                    "served": true,
                    "storage": true,
                    "schema": {
                        "openAPIV3Schema": {
                            "type": "object",
                            "properties": {
                                "spec": {
                                    "type": "object",
                                    "properties": {
                                        "compositeTypeRef": {
                                            "type": "object",
                                            "properties": {
                                                "apiVersion": {"type": "string"},
                                                "kind": {"type": "string"}
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }]
            }
        })
    }

    /// Serve `response` for `GET <route>` on a fresh mock server, fetch that
    /// URL through a new `CRDFetcher`, and hand back whatever the fetch returned.
    async fn fetch_via_mock(
        route: &str,
        response: ResponseTemplate,
    ) -> anyhow::Result<Vec<crate::crd::CRD>> {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path(route))
            .respond_with(response)
            .mount(&server)
            .await;

        let fetcher = CRDFetcher::new().unwrap();
        let target = format!("{}{}", &server.uri(), route);
        fetcher.fetch_from_url(&target).await
    }

    #[tokio::test]
    async fn test_fetch_single_yaml_file() {
        let body = serde_yaml::to_string(&sample_crd()).unwrap();
        let template = ResponseTemplate::new(200).set_body_string(body);

        let crds = fetch_via_mock("/test.yaml", template).await.unwrap();

        assert_eq!(crds.len(), 1);
        let only = &crds[0];
        assert_eq!(only.spec.group, "apiextensions.crossplane.io");
        assert_eq!(only.spec.names.kind, "Composition");
    }

    #[tokio::test]
    async fn test_error_handling_404() {
        let outcome = fetch_via_mock("/missing.yaml", ResponseTemplate::new(404)).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_error_handling_invalid_yaml() {
        let template = ResponseTemplate::new(200).set_body_string("not: valid: yaml: content:");

        let outcome = fetch_via_mock("/invalid.yaml", template).await;

        assert!(outcome.is_err());
    }
}
466}