data_gov/
client.rs

1use futures::StreamExt;
2use indicatif::{ProgressBar, ProgressStyle};
3use is_terminal::IsTerminal;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use tokio::fs::File;
7use tokio::io::AsyncWriteExt;
8use url::Url;
9
10use crate::config::DataGovConfig;
11use crate::error::{DataGovError, Result};
12use data_gov_ckan::{
13    CkanClient,
14    models::{Package, PackageSearchResult, Resource},
15};
16
17/// High-level client for interacting with data.gov
18///
19/// This client wraps the CKAN client and provides additional functionality
20/// for downloading resources, managing files, and working with data.gov specifically.
21#[derive(Debug)]
22pub struct DataGovClient {
23    ckan: CkanClient,
24    config: DataGovConfig,
25    http_client: reqwest::Client,
26}
27
28impl DataGovClient {
29    /// Create a new DataGov client with default configuration
30    pub fn new() -> Result<Self> {
31        Self::with_config(DataGovConfig::new())
32    }
33
34    /// Create a new DataGov client with custom configuration
35    pub fn with_config(config: DataGovConfig) -> Result<Self> {
36        let ckan = CkanClient::new(config.ckan_config.clone());
37
38        // Create HTTP client with timeout for downloads
39        let http_client = reqwest::Client::builder()
40            .timeout(std::time::Duration::from_secs(config.download_timeout_secs))
41            .user_agent(&config.user_agent)
42            .build()?;
43
44        Ok(Self {
45            ckan,
46            config,
47            http_client,
48        })
49    }
50
51    // === Search and Discovery ===
52
53    /// Search for datasets on data.gov
54    ///
55    /// # Arguments
56    /// * `query` - Search terms (searches titles, descriptions, tags)
57    /// * `limit` - Maximum number of results (default: 10, max: 1000)
58    /// * `offset` - Number of results to skip for pagination (default: 0)
59    /// * `organization` - Filter by organization name (optional)
60    /// * `format` - Filter by resource format (optional, e.g., "CSV", "JSON")
61    ///
62    /// # Examples
63    ///
64    /// Basic search:
65    /// ```rust,no_run
66    /// # use data_gov::DataGovClient;
67    /// # #[tokio::main]
68    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
69    /// let client = DataGovClient::new()?;
70    /// let results = client.search("climate data", Some(20), None, None, None).await?;
71    /// # Ok(())
72    /// # }
73    /// ```
74    ///
75    /// Search with filters:
76    /// ```rust,no_run
77    /// # use data_gov::DataGovClient;
78    /// # #[tokio::main]
79    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
80    /// # let client = DataGovClient::new()?;
81    /// let results = client.search("energy", Some(10), None, Some("doe-gov"), Some("CSV")).await?;
82    /// # Ok(())
83    /// # }
84    /// ```
85    pub async fn search(
86        &self,
87        query: &str,
88        limit: Option<i32>,
89        offset: Option<i32>,
90        organization: Option<&str>,
91        format: Option<&str>,
92    ) -> Result<PackageSearchResult> {
93        // Build filter query for advanced filtering
94        let fq = match (organization, format) {
95            (Some(org), Some(fmt)) => Some(format!(
96                r#"organization:"{}" AND res_format:"{}""#,
97                org, fmt
98            )),
99            (Some(org), None) => Some(format!(r#"organization:"{}""#, org)),
100            (None, Some(fmt)) => Some(format!(r#"res_format:"{}""#, fmt)),
101            (None, None) => None,
102        };
103
104        let result = self
105            .ckan
106            .package_search(Some(query), limit, offset, fq.as_deref())
107            .await?;
108
109        Ok(result)
110    }
111
112    /// Get detailed information about a dataset
113    pub async fn get_dataset(&self, dataset_id: &str) -> Result<Package> {
114        let package = self.ckan.package_show(dataset_id).await?;
115        Ok(package)
116    }
117
118    /// Get autocomplete suggestions for dataset names
119    pub async fn autocomplete_datasets(
120        &self,
121        partial: &str,
122        limit: Option<i32>,
123    ) -> Result<Vec<String>> {
124        let suggestions = self.ckan.dataset_autocomplete(Some(partial), limit).await?;
125        Ok(suggestions.into_iter().filter_map(|s| s.name).collect())
126    }
127
128    /// Get list of organizations (government agencies)
129    pub async fn list_organizations(&self, limit: Option<i32>) -> Result<Vec<String>> {
130        let orgs = self.ckan.organization_list(None, limit, None).await?;
131        Ok(orgs)
132    }
133
134    /// Get autocomplete suggestions for organizations
135    pub async fn autocomplete_organizations(
136        &self,
137        partial: &str,
138        limit: Option<i32>,
139    ) -> Result<Vec<String>> {
140        let suggestions = self
141            .ckan
142            .organization_autocomplete(Some(partial), limit)
143            .await?;
144        Ok(suggestions.into_iter().filter_map(|s| s.name).collect())
145    }
146
147    // === Resource Management ===
148
149    /// Find downloadable resources in a dataset
150    ///
151    /// Returns a list of resources that have URLs and are likely downloadable files
152    pub fn get_downloadable_resources(package: &Package) -> Vec<Resource> {
153        package
154            .resources
155            .as_ref()
156            .unwrap_or(&Vec::new())
157            .iter()
158            .filter(|resource| {
159                // Has a URL and is not an API endpoint
160                resource.url.is_some()
161                    && resource.url_type.as_deref() != Some("api")
162                    && resource.format.is_some()
163            })
164            .cloned()
165            .collect()
166    }
167
168    /// Get the best download filename for a resource
169    pub fn get_resource_filename(resource: &Resource, fallback_name: Option<&str>) -> String {
170        // Try resource name first
171        if let Some(name) = &resource.name {
172            if let Some(format) = &resource.format {
173                if name.ends_with(&format!(".{}", format.to_lowercase())) {
174                    return name.clone();
175                } else {
176                    return format!("{}.{}", name, format.to_lowercase());
177                }
178            }
179            return name.clone();
180        }
181
182        // Try to extract filename from URL
183        if let Some(url) = &resource.url {
184            if let Ok(parsed_url) = Url::parse(url) {
185                if let Some(segments) = parsed_url.path_segments() {
186                    if let Some(filename) = segments.last() {
187                        if !filename.is_empty() && filename.contains('.') {
188                            return filename.to_string();
189                        }
190                    }
191                }
192            }
193        }
194
195        // Use fallback with format extension
196        let base_name = fallback_name.unwrap_or("data");
197        if let Some(format) = &resource.format {
198            format!("{}.{}", base_name, format.to_lowercase())
199        } else {
200            format!("{}.dat", base_name)
201        }
202    }
203
204    // === File Downloads ===
205
206    /// Download a resource from a dataset to its dataset-specific directory
207    ///
208    /// # Arguments
209    /// * `resource` - The resource to download
210    /// * `dataset_name` - Name of the dataset (used for subdirectory)
211    ///
212    /// Returns the path where the file was saved
213    pub async fn download_dataset_resource(
214        &self,
215        resource: &Resource,
216        dataset_name: &str,
217    ) -> Result<PathBuf> {
218        let dataset_dir = self.config.get_dataset_download_dir(dataset_name);
219        let filename = Self::get_resource_filename(resource, None);
220        let output_path = dataset_dir.join(filename);
221
222        let url = resource
223            .url
224            .as_ref()
225            .ok_or_else(|| DataGovError::resource_not_found("Resource has no URL"))?;
226
227        self.download_file(url, &output_path, resource.name.as_deref())
228            .await?;
229        Ok(output_path)
230    }
231
232    /// Download a resource to a file
233    ///
234    /// # Arguments
235    /// * `resource` - The resource to download
236    /// * `output_path` - Where to save the file (if None, uses base download directory)
237    ///
238    /// Returns the path where the file was saved
239    pub async fn download_resource(
240        &self,
241        resource: &Resource,
242        output_path: Option<PathBuf>,
243    ) -> Result<PathBuf> {
244        let url = resource
245            .url
246            .as_ref()
247            .ok_or_else(|| DataGovError::resource_not_found("Resource has no URL"))?;
248
249        let output_path = match output_path {
250            Some(path) => path,
251            None => {
252                let filename = Self::get_resource_filename(resource, None);
253                self.config.get_base_download_dir().join(filename)
254            }
255        };
256
257        self.download_file(url, &output_path, resource.name.as_deref())
258            .await?;
259        Ok(output_path)
260    }
261
262    /// Download multiple resources concurrently
263    ///
264    /// Returns a vector of results, each containing either the download path or an error
265    pub async fn download_resources(
266        &self,
267        resources: &[Resource],
268        output_dir: Option<&Path>,
269    ) -> Vec<Result<PathBuf>> {
270        // For multiple resources, use simple concurrent downloads
271        if resources.len() > 1 {
272            if self.config.show_progress && std::env::var("NO_PROGRESS").is_err() {
273                println!("Downloading {} resources...", resources.len());
274            }
275
276            let output_dir = output_dir
277                .map(|p| p.to_path_buf())
278                .unwrap_or_else(|| self.config.get_base_download_dir());
279            let semaphore = Arc::new(tokio::sync::Semaphore::new(
280                self.config.max_concurrent_downloads,
281            ));
282
283            let mut futures = Vec::new();
284
285            for resource in resources {
286                let resource = resource.clone();
287                let output_dir = output_dir.clone();
288                let semaphore = semaphore.clone();
289                let http_client = self.http_client.clone();
290                let config = self.config.clone();
291
292                let future = async move {
293                    let _permit = semaphore.acquire().await.unwrap();
294
295                    let url = match &resource.url {
296                        Some(url) => url,
297                        None => {
298                            return Err(DataGovError::resource_not_found("Resource has no URL"));
299                        }
300                    };
301
302                    let filename = DataGovClient::get_resource_filename(&resource, None);
303                    let output_path = output_dir.join(filename);
304
305                    DataGovClient::download_file_simple(
306                        &http_client,
307                        &config,
308                        url,
309                        &output_path,
310                        resource.name.as_deref(),
311                    )
312                    .await?;
313
314                    Ok(output_path)
315                };
316
317                futures.push(future);
318            }
319
320            futures::future::join_all(futures).await
321        } else if resources.len() == 1 {
322            // For single resource, use regular download with progress bar
323            let resource = &resources[0];
324            let output_path = match output_dir {
325                Some(dir) => {
326                    let filename = Self::get_resource_filename(resource, None);
327                    Some(dir.join(filename))
328                }
329                None => None,
330            };
331
332            vec![self.download_resource(resource, output_path).await]
333        } else {
334            // No resources
335            vec![]
336        }
337    }
338
339    /// Download multiple resources from a dataset to its dataset-specific directory
340    ///
341    /// Returns a vector of results, each containing either the download path or an error
342    pub async fn download_dataset_resources(
343        &self,
344        resources: &[Resource],
345        dataset_name: &str,
346    ) -> Vec<Result<PathBuf>> {
347        // For multiple resources, use simple concurrent downloads
348        if resources.len() > 1 {
349            if self.config.show_progress && std::env::var("NO_PROGRESS").is_err() {
350                println!("Downloading {} resources...", resources.len());
351            }
352
353            let semaphore = Arc::new(tokio::sync::Semaphore::new(
354                self.config.max_concurrent_downloads,
355            ));
356            let mut futures = Vec::new();
357
358            for resource in resources {
359                let resource = resource.clone();
360                let dataset_name = dataset_name.to_string();
361                let semaphore = semaphore.clone();
362                let http_client = self.http_client.clone();
363                let config = self.config.clone();
364
365                let future = async move {
366                    let _permit = semaphore.acquire().await.unwrap();
367
368                    let url = match &resource.url {
369                        Some(url) => url,
370                        None => {
371                            return Err(DataGovError::resource_not_found("Resource has no URL"));
372                        }
373                    };
374
375                    let dataset_dir = config.get_base_download_dir().join(dataset_name);
376                    let filename = DataGovClient::get_resource_filename(&resource, None);
377                    let output_path = dataset_dir.join(filename);
378
379                    DataGovClient::download_file_simple(
380                        &http_client,
381                        &config,
382                        url,
383                        &output_path,
384                        resource.name.as_deref(),
385                    )
386                    .await?;
387
388                    Ok(output_path)
389                };
390
391                futures.push(future);
392            }
393
394            futures::future::join_all(futures).await
395        } else if resources.len() == 1 {
396            // For single resource, use regular download with progress bar
397            vec![
398                self.download_dataset_resource(&resources[0], dataset_name)
399                    .await,
400            ]
401        } else {
402            // No resources
403            vec![]
404        }
405    }
406
407    /// Download a file from a URL with progress tracking
408    async fn download_file(
409        &self,
410        url: &str,
411        output_path: &Path,
412        display_name: Option<&str>,
413    ) -> Result<()> {
414        // Create parent directories if they don't exist
415        if let Some(parent) = output_path.parent() {
416            tokio::fs::create_dir_all(parent).await?;
417        }
418
419        let response = self.http_client.get(url).send().await?;
420
421        if !response.status().is_success() {
422            return Err(DataGovError::download_error(format!(
423                "HTTP {} while downloading {}",
424                response.status(),
425                url
426            )));
427        }
428
429        let total_size = response.content_length();
430        let display_name = display_name.unwrap_or("file");
431
432        // Setup progress indication based on TTY and environment
433        let should_show_progress =
434            self.config.show_progress && std::env::var("NO_PROGRESS").is_err(); // Respect NO_PROGRESS env var
435
436        let (progress_bar, show_simple_progress) = if should_show_progress {
437            if std::io::stdout().is_terminal() && std::env::var("FORCE_SIMPLE_PROGRESS").is_err() {
438                // TTY: Show fancy progress bar (unless forced simple)
439                let pb = if let Some(size) = total_size {
440                    ProgressBar::new(size)
441                } else {
442                    ProgressBar::new_spinner()
443                };
444
445                let template = if total_size.is_some() {
446                    "{msg} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})"
447                } else {
448                    "{msg} [{spinner:.cyan/blue}] {bytes} ({bytes_per_sec})"
449                };
450
451                pb.set_style(
452                    ProgressStyle::default_bar()
453                        .template(template)
454                        .unwrap_or_else(|_| {
455                            if total_size.is_some() {
456                                ProgressStyle::default_bar().progress_chars("█▉▊▋▌▍▎▏ ")
457                            } else {
458                                ProgressStyle::default_spinner()
459                            }
460                        })
461                        .progress_chars("█▉▊▋▌▍▎▏ "),
462                );
463                pb.set_message(format!("Downloading {}", display_name));
464                (Some(pb), false)
465            } else {
466                // Non-TTY or forced simple: Show simple text progress
467                if total_size.is_some() {
468                    println!(
469                        "Downloading {} ({} bytes)...",
470                        display_name,
471                        total_size.unwrap()
472                    );
473                } else {
474                    println!("Downloading {} ...", display_name);
475                }
476                (None, true)
477            }
478        } else {
479            // Progress disabled
480            (None, false)
481        };
482
483        let mut file = File::create(output_path).await?;
484        let mut stream = response.bytes_stream();
485        let mut downloaded = 0u64;
486
487        while let Some(chunk) = stream.next().await {
488            let chunk = chunk?;
489            file.write_all(&chunk).await?;
490            downloaded += chunk.len() as u64;
491
492            if let Some(pb) = &progress_bar {
493                pb.set_position(downloaded);
494            }
495        }
496
497        if let Some(pb) = progress_bar {
498            // For unknown size downloads, set the final position before finishing
499            if total_size.is_none() {
500                pb.set_length(downloaded);
501                pb.set_position(downloaded);
502            }
503            pb.finish_with_message(format!("Downloaded {}", display_name));
504        } else if show_simple_progress {
505            // For non-TTY, show completion message
506            println!("✓ Downloaded {}", display_name);
507        }
508
509        Ok(())
510    }
511
512    /// Simple download method for concurrent downloads (no progress bars to avoid conflicts)
513    async fn download_file_simple(
514        http_client: &reqwest::Client,
515        config: &DataGovConfig,
516        url: &str,
517        output_path: &Path,
518        display_name: Option<&str>,
519    ) -> Result<()> {
520        // Create parent directories if they don't exist
521        if let Some(parent) = output_path.parent() {
522            tokio::fs::create_dir_all(parent).await?;
523        }
524
525        let display_name = display_name.unwrap_or("file");
526
527        // Show simple text progress for concurrent downloads
528        if config.show_progress && std::env::var("NO_PROGRESS").is_err() {
529            println!("Downloading {} ...", display_name);
530        }
531
532        let response = http_client.get(url).send().await?;
533
534        if !response.status().is_success() {
535            return Err(DataGovError::download_error(format!(
536                "HTTP {} while downloading {}",
537                response.status(),
538                url
539            )));
540        }
541
542        let mut file = File::create(output_path).await?;
543        let mut stream = response.bytes_stream();
544
545        while let Some(chunk) = stream.next().await {
546            let chunk = chunk?;
547            file.write_all(&chunk).await?;
548        }
549
550        // Show completion message
551        if config.show_progress && std::env::var("NO_PROGRESS").is_err() {
552            println!("✓ Downloaded {}", display_name);
553        }
554
555        Ok(())
556    }
557
558    // === Utility Methods ===
559
560    /// Check if the base download directory exists and is writable
561    pub async fn validate_download_dir(&self) -> Result<()> {
562        let base_dir = self.config.get_base_download_dir();
563
564        if !base_dir.exists() {
565            tokio::fs::create_dir_all(&base_dir).await?;
566        }
567
568        if !base_dir.is_dir() {
569            return Err(DataGovError::config_error(format!(
570                "Download path is not a directory: {:?}",
571                base_dir
572            )));
573        }
574
575        // Test write permissions by creating a temporary file
576        let test_file = base_dir.join(".write_test");
577        tokio::fs::write(&test_file, b"test").await?;
578        tokio::fs::remove_file(&test_file).await?;
579
580        Ok(())
581    }
582
583    /// Get the current base download directory
584    pub fn download_dir(&self) -> PathBuf {
585        self.config.get_base_download_dir()
586    }
587
588    /// Get the underlying CKAN client for advanced operations
589    pub fn ckan_client(&self) -> &CkanClient {
590        &self.ckan
591    }
592}
593
594impl Default for DataGovClient {
595    fn default() -> Self {
596        Self::new().expect("Failed to create default DataGovClient")
597    }
598}