data_gov/
client.rs

1use futures::StreamExt;
2use indicatif::{ProgressBar, ProgressStyle};
3use is_terminal::IsTerminal;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use tokio::fs::File;
7use tokio::io::AsyncWriteExt;
8use url::Url;
9
10use crate::config::DataGovConfig;
11use crate::error::{DataGovError, Result};
12use data_gov_ckan::{
13    CkanClient,
14    models::{Package, PackageSearchResult, Resource},
15};
16
17/// Async client for exploring data.gov datasets.
18///
19/// `DataGovClient` layers ergonomic helpers on top of the lower-level
20/// [`data_gov_ckan::CkanClient`]. In addition to search and metadata lookups it
21/// handles download destinations, progress reporting, and colour-aware output
22/// that matches the `data-gov` CLI defaults.
23#[derive(Debug)]
24pub struct DataGovClient {
25    ckan: CkanClient,
26    config: DataGovConfig,
27    http_client: reqwest::Client,
28}
29
30impl DataGovClient {
31    /// Create a new DataGov client with default configuration
32    pub fn new() -> Result<Self> {
33        Self::with_config(DataGovConfig::new())
34    }
35
36    /// Create a new DataGov client with custom configuration
37    pub fn with_config(config: DataGovConfig) -> Result<Self> {
38        let ckan = CkanClient::new(config.ckan_config.clone());
39
40        // Create HTTP client with timeout for downloads
41        let http_client = reqwest::Client::builder()
42            .timeout(std::time::Duration::from_secs(config.download_timeout_secs))
43            .user_agent(&config.user_agent)
44            .build()?;
45
46        Ok(Self {
47            ckan,
48            config,
49            http_client,
50        })
51    }
52
53    // === Search and Discovery ===
54
55    /// Search for datasets on data.gov.
56    ///
57    /// # Arguments
58    /// * `query` - Search terms (searches titles, descriptions, tags)
59    /// * `limit` - Maximum number of results (default: 10, max: 1000)
60    /// * `offset` - Number of results to skip for pagination (default: 0)
61    /// * `organization` - Filter by organization name (optional)
62    /// * `format` - Filter by resource format (optional, e.g., "CSV", "JSON")
63    ///
64    /// # Examples
65    ///
66    /// Basic search:
67    /// ```rust,no_run
68    /// # use data_gov::DataGovClient;
69    /// # #[tokio::main]
70    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
71    /// let client = DataGovClient::new()?;
72    /// let results = client.search("climate data", Some(20), None, None, None).await?;
73    /// # Ok(())
74    /// # }
75    /// ```
76    ///
77    /// Search with filters:
78    /// ```rust,no_run
79    /// # use data_gov::DataGovClient;
80    /// # #[tokio::main]
81    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
82    /// # let client = DataGovClient::new()?;
83    /// let results = client.search("energy", Some(10), None, Some("doe-gov"), Some("CSV")).await?;
84    /// # Ok(())
85    /// # }
86    /// ```
87    pub async fn search(
88        &self,
89        query: &str,
90        limit: Option<i32>,
91        offset: Option<i32>,
92        organization: Option<&str>,
93        format: Option<&str>,
94    ) -> Result<PackageSearchResult> {
95        // Build filter query for advanced filtering
96        let fq = match (organization, format) {
97            (Some(org), Some(fmt)) => Some(format!(
98                r#"organization:"{}" AND res_format:"{}""#,
99                org, fmt
100            )),
101            (Some(org), None) => Some(format!(r#"organization:"{}""#, org)),
102            (None, Some(fmt)) => Some(format!(r#"res_format:"{}""#, fmt)),
103            (None, None) => None,
104        };
105
106        let result = self
107            .ckan
108            .package_search(Some(query), limit, offset, fq.as_deref())
109            .await?;
110
111        Ok(result)
112    }
113
114    /// Fetch the full `package_show` payload for a dataset.
115    pub async fn get_dataset(&self, dataset_id: &str) -> Result<Package> {
116        let package = self.ckan.package_show(dataset_id).await?;
117        Ok(package)
118    }
119
120    /// Fetch dataset name suggestions for interactive prompts.
121    pub async fn autocomplete_datasets(
122        &self,
123        partial: &str,
124        limit: Option<i32>,
125    ) -> Result<Vec<String>> {
126        let suggestions = self.ckan.dataset_autocomplete(Some(partial), limit).await?;
127        Ok(suggestions.into_iter().filter_map(|s| s.name).collect())
128    }
129
130    /// List the publisher slugs for government organizations.
131    pub async fn list_organizations(&self, limit: Option<i32>) -> Result<Vec<String>> {
132        let orgs = self.ckan.organization_list(None, limit, None).await?;
133        Ok(orgs)
134    }
135
136    /// Fetch organization name suggestions for interactive prompts.
137    pub async fn autocomplete_organizations(
138        &self,
139        partial: &str,
140        limit: Option<i32>,
141    ) -> Result<Vec<String>> {
142        let suggestions = self
143            .ckan
144            .organization_autocomplete(Some(partial), limit)
145            .await?;
146        Ok(suggestions.into_iter().filter_map(|s| s.name).collect())
147    }
148
149    // === Resource Management ===
150
151    /// Return resources that look like downloadable files.
152    ///
153    /// The returned list is filtered to resources that expose a direct URL, are
154    /// not marked as API endpoints, and advertise a file format.
155    pub fn get_downloadable_resources(package: &Package) -> Vec<Resource> {
156        package
157            .resources
158            .as_ref()
159            .unwrap_or(&Vec::new())
160            .iter()
161            .filter(|resource| {
162                // Has a URL and is not an API endpoint
163                resource.url.is_some()
164                    && resource.url_type.as_deref() != Some("api")
165                    && resource.format.is_some()
166            })
167            .cloned()
168            .collect()
169    }
170
171    /// Pick a filesystem-friendly filename for a resource download.
172    pub fn get_resource_filename(resource: &Resource, fallback_name: Option<&str>) -> String {
173        // Try resource name first
174        if let Some(name) = &resource.name {
175            if let Some(format) = &resource.format {
176                if name.ends_with(&format!(".{}", format.to_lowercase())) {
177                    return name.clone();
178                } else {
179                    return format!("{}.{}", name, format.to_lowercase());
180                }
181            }
182            return name.clone();
183        }
184
185        // Try to extract filename from URL
186        if let Some(url) = &resource.url {
187            if let Ok(parsed_url) = Url::parse(url) {
188                if let Some(segments) = parsed_url.path_segments() {
189                    if let Some(filename) = segments.last() {
190                        if !filename.is_empty() && filename.contains('.') {
191                            return filename.to_string();
192                        }
193                    }
194                }
195            }
196        }
197
198        // Use fallback with format extension
199        let base_name = fallback_name.unwrap_or("data");
200        if let Some(format) = &resource.format {
201            format!("{}.{}", base_name, format.to_lowercase())
202        } else {
203            format!("{}.dat", base_name)
204        }
205    }
206
207    // === File Downloads ===
208
209    /// Download a single resource into the dataset-specific directory.
210    ///
211    /// # Arguments
212    /// * `resource` - The resource to download
213    /// * `dataset_name` - Name of the dataset (used for subdirectory)
214    ///
215    /// Returns the path where the file was saved
216    pub async fn download_dataset_resource(
217        &self,
218        resource: &Resource,
219        dataset_name: &str,
220    ) -> Result<PathBuf> {
221        let dataset_dir = self.config.get_dataset_download_dir(dataset_name);
222        let filename = Self::get_resource_filename(resource, None);
223        let output_path = dataset_dir.join(filename);
224
225        let url = resource
226            .url
227            .as_ref()
228            .ok_or_else(|| DataGovError::resource_not_found("Resource has no URL"))?;
229
230        self.download_file(url, &output_path, resource.name.as_deref())
231            .await?;
232        Ok(output_path)
233    }
234
235    /// Download a single resource to a specific path.
236    ///
237    /// # Arguments
238    /// * `resource` - The resource to download
239    /// * `output_path` - Where to save the file (if None, uses base download directory)
240    ///
241    /// Returns the path where the file was saved
242    pub async fn download_resource(
243        &self,
244        resource: &Resource,
245        output_path: Option<PathBuf>,
246    ) -> Result<PathBuf> {
247        let url = resource
248            .url
249            .as_ref()
250            .ok_or_else(|| DataGovError::resource_not_found("Resource has no URL"))?;
251
252        let output_path = match output_path {
253            Some(path) => path,
254            None => {
255                let filename = Self::get_resource_filename(resource, None);
256                self.config.get_base_download_dir().join(filename)
257            }
258        };
259
260        self.download_file(url, &output_path, resource.name.as_deref())
261            .await?;
262        Ok(output_path)
263    }
264
265    /// Download multiple resources concurrently.
266    ///
267    /// Returns one [`Result`] per resource so callers can inspect partial failures.
268    pub async fn download_resources(
269        &self,
270        resources: &[Resource],
271        output_dir: Option<&Path>,
272    ) -> Vec<Result<PathBuf>> {
273        // For multiple resources, use simple concurrent downloads
274        if resources.len() > 1 {
275            if self.config.show_progress && std::env::var("NO_PROGRESS").is_err() {
276                println!("Downloading {} resources...", resources.len());
277            }
278
279            let output_dir = output_dir
280                .map(|p| p.to_path_buf())
281                .unwrap_or_else(|| self.config.get_base_download_dir());
282            let semaphore = Arc::new(tokio::sync::Semaphore::new(
283                self.config.max_concurrent_downloads,
284            ));
285
286            let mut futures = Vec::new();
287
288            for resource in resources {
289                let resource = resource.clone();
290                let output_dir = output_dir.clone();
291                let semaphore = semaphore.clone();
292                let http_client = self.http_client.clone();
293                let config = self.config.clone();
294
295                let future = async move {
296                    let _permit = semaphore.acquire().await.unwrap();
297
298                    let url = match &resource.url {
299                        Some(url) => url,
300                        None => {
301                            return Err(DataGovError::resource_not_found("Resource has no URL"));
302                        }
303                    };
304
305                    let filename = DataGovClient::get_resource_filename(&resource, None);
306                    let output_path = output_dir.join(filename);
307
308                    DataGovClient::download_file_simple(
309                        &http_client,
310                        &config,
311                        url,
312                        &output_path,
313                        resource.name.as_deref(),
314                    )
315                    .await?;
316
317                    Ok(output_path)
318                };
319
320                futures.push(future);
321            }
322
323            futures::future::join_all(futures).await
324        } else if resources.len() == 1 {
325            // For single resource, use regular download with progress bar
326            let resource = &resources[0];
327            let output_path = match output_dir {
328                Some(dir) => {
329                    let filename = Self::get_resource_filename(resource, None);
330                    Some(dir.join(filename))
331                }
332                None => None,
333            };
334
335            vec![self.download_resource(resource, output_path).await]
336        } else {
337            // No resources
338            vec![]
339        }
340    }
341
342    /// Download multiple resources into the dataset-specific directory.
343    ///
344    /// Returns one [`Result`] per resource so callers can inspect partial failures.
345    pub async fn download_dataset_resources(
346        &self,
347        resources: &[Resource],
348        dataset_name: &str,
349    ) -> Vec<Result<PathBuf>> {
350        // For multiple resources, use simple concurrent downloads
351        if resources.len() > 1 {
352            if self.config.show_progress && std::env::var("NO_PROGRESS").is_err() {
353                println!("Downloading {} resources...", resources.len());
354            }
355
356            let semaphore = Arc::new(tokio::sync::Semaphore::new(
357                self.config.max_concurrent_downloads,
358            ));
359            let mut futures = Vec::new();
360
361            for resource in resources {
362                let resource = resource.clone();
363                let dataset_name = dataset_name.to_string();
364                let semaphore = semaphore.clone();
365                let http_client = self.http_client.clone();
366                let config = self.config.clone();
367
368                let future = async move {
369                    let _permit = semaphore.acquire().await.unwrap();
370
371                    let url = match &resource.url {
372                        Some(url) => url,
373                        None => {
374                            return Err(DataGovError::resource_not_found("Resource has no URL"));
375                        }
376                    };
377
378                    let dataset_dir = config.get_base_download_dir().join(dataset_name);
379                    let filename = DataGovClient::get_resource_filename(&resource, None);
380                    let output_path = dataset_dir.join(filename);
381
382                    DataGovClient::download_file_simple(
383                        &http_client,
384                        &config,
385                        url,
386                        &output_path,
387                        resource.name.as_deref(),
388                    )
389                    .await?;
390
391                    Ok(output_path)
392                };
393
394                futures.push(future);
395            }
396
397            futures::future::join_all(futures).await
398        } else if resources.len() == 1 {
399            // For single resource, use regular download with progress bar
400            vec![
401                self.download_dataset_resource(&resources[0], dataset_name)
402                    .await,
403            ]
404        } else {
405            // No resources
406            vec![]
407        }
408    }
409
410    /// Download a file from a URL with progress tracking
411    async fn download_file(
412        &self,
413        url: &str,
414        output_path: &Path,
415        display_name: Option<&str>,
416    ) -> Result<()> {
417        // Create parent directories if they don't exist
418        if let Some(parent) = output_path.parent() {
419            tokio::fs::create_dir_all(parent).await?;
420        }
421
422        let response = self.http_client.get(url).send().await?;
423
424        if !response.status().is_success() {
425            return Err(DataGovError::download_error(format!(
426                "HTTP {} while downloading {}",
427                response.status(),
428                url
429            )));
430        }
431
432        let total_size = response.content_length();
433        let display_name = display_name.unwrap_or("file");
434
435        // Setup progress indication based on TTY and environment
436        let should_show_progress =
437            self.config.show_progress && std::env::var("NO_PROGRESS").is_err(); // Respect NO_PROGRESS env var
438
439        let (progress_bar, show_simple_progress) = if should_show_progress {
440            if std::io::stdout().is_terminal() && std::env::var("FORCE_SIMPLE_PROGRESS").is_err() {
441                // TTY: Show fancy progress bar (unless forced simple)
442                let pb = if let Some(size) = total_size {
443                    ProgressBar::new(size)
444                } else {
445                    ProgressBar::new_spinner()
446                };
447
448                let template = if total_size.is_some() {
449                    "{msg} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})"
450                } else {
451                    "{msg} [{spinner:.cyan/blue}] {bytes} ({bytes_per_sec})"
452                };
453
454                pb.set_style(
455                    ProgressStyle::default_bar()
456                        .template(template)
457                        .unwrap_or_else(|_| {
458                            if total_size.is_some() {
459                                ProgressStyle::default_bar().progress_chars("█▉▊▋▌▍▎▏ ")
460                            } else {
461                                ProgressStyle::default_spinner()
462                            }
463                        })
464                        .progress_chars("█▉▊▋▌▍▎▏ "),
465                );
466                pb.set_message(format!("Downloading {}", display_name));
467                (Some(pb), false)
468            } else {
469                // Non-TTY or forced simple: Show simple text progress
470                if total_size.is_some() {
471                    println!(
472                        "Downloading {} ({} bytes)...",
473                        display_name,
474                        total_size.unwrap()
475                    );
476                } else {
477                    println!("Downloading {} ...", display_name);
478                }
479                (None, true)
480            }
481        } else {
482            // Progress disabled
483            (None, false)
484        };
485
486        let mut file = File::create(output_path).await?;
487        let mut stream = response.bytes_stream();
488        let mut downloaded = 0u64;
489
490        while let Some(chunk) = stream.next().await {
491            let chunk = chunk?;
492            file.write_all(&chunk).await?;
493            downloaded += chunk.len() as u64;
494
495            if let Some(pb) = &progress_bar {
496                pb.set_position(downloaded);
497            }
498        }
499
500        if let Some(pb) = progress_bar {
501            // For unknown size downloads, set the final position before finishing
502            if total_size.is_none() {
503                pb.set_length(downloaded);
504                pb.set_position(downloaded);
505            }
506            pb.finish_with_message(format!("Downloaded {}", display_name));
507        } else if show_simple_progress {
508            // For non-TTY, show completion message
509            println!("✓ Downloaded {}", display_name);
510        }
511
512        Ok(())
513    }
514
515    /// Simple download method for concurrent downloads (no progress bars to avoid conflicts)
516    async fn download_file_simple(
517        http_client: &reqwest::Client,
518        config: &DataGovConfig,
519        url: &str,
520        output_path: &Path,
521        display_name: Option<&str>,
522    ) -> Result<()> {
523        // Create parent directories if they don't exist
524        if let Some(parent) = output_path.parent() {
525            tokio::fs::create_dir_all(parent).await?;
526        }
527
528        let display_name = display_name.unwrap_or("file");
529
530        // Show simple text progress for concurrent downloads
531        if config.show_progress && std::env::var("NO_PROGRESS").is_err() {
532            println!("Downloading {} ...", display_name);
533        }
534
535        let response = http_client.get(url).send().await?;
536
537        if !response.status().is_success() {
538            return Err(DataGovError::download_error(format!(
539                "HTTP {} while downloading {}",
540                response.status(),
541                url
542            )));
543        }
544
545        let mut file = File::create(output_path).await?;
546        let mut stream = response.bytes_stream();
547
548        while let Some(chunk) = stream.next().await {
549            let chunk = chunk?;
550            file.write_all(&chunk).await?;
551        }
552
553        // Show completion message
554        if config.show_progress && std::env::var("NO_PROGRESS").is_err() {
555            println!("✓ Downloaded {}", display_name);
556        }
557
558        Ok(())
559    }
560
561    // === Utility Methods ===
562
563    /// Check if the base download directory exists and is writable
564    pub async fn validate_download_dir(&self) -> Result<()> {
565        let base_dir = self.config.get_base_download_dir();
566
567        if !base_dir.exists() {
568            tokio::fs::create_dir_all(&base_dir).await?;
569        }
570
571        if !base_dir.is_dir() {
572            return Err(DataGovError::config_error(format!(
573                "Download path is not a directory: {:?}",
574                base_dir
575            )));
576        }
577
578        // Test write permissions by creating a temporary file
579        let test_file = base_dir.join(".write_test");
580        tokio::fs::write(&test_file, b"test").await?;
581        tokio::fs::remove_file(&test_file).await?;
582
583        Ok(())
584    }
585
586    /// Get the current base download directory
587    pub fn download_dir(&self) -> PathBuf {
588        self.config.get_base_download_dir()
589    }
590
591    /// Get the underlying CKAN client for advanced operations
592    pub fn ckan_client(&self) -> &CkanClient {
593        &self.ckan
594    }
595}
596
597impl Default for DataGovClient {
598    fn default() -> Self {
599        Self::new().expect("Failed to create default DataGovClient")
600    }
601}