data_gov/
client.rs

1use futures::StreamExt;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use tokio::fs::File;
5use tokio::io::AsyncWriteExt;
6use url::Url;
7
8use crate::config::DataGovConfig;
9use crate::error::{DataGovError, Result};
10use crate::ui::{
11    DownloadBatch, DownloadFailed, DownloadFinished, DownloadProgress, DownloadStarted,
12    StatusReporter,
13};
14use data_gov_ckan::{
15    CkanClient,
16    models::{Package, PackageSearchResult, Resource},
17};
18
19/// Async client for exploring data.gov datasets.
20///
21/// `DataGovClient` layers ergonomic helpers on top of the lower-level
22/// [`data_gov_ckan::CkanClient`]. In addition to search and metadata lookups it
23/// handles download destinations, progress reporting, and colour-aware output
24/// that matches the `data-gov` CLI defaults.
25#[derive(Debug)]
26pub struct DataGovClient {
27    ckan: CkanClient,
28    config: DataGovConfig,
29    http_client: reqwest::Client,
30}
31
32impl DataGovClient {
33    /// Create a new DataGov client with default configuration
34    pub fn new() -> Result<Self> {
35        Self::with_config(DataGovConfig::new())
36    }
37
38    /// Create a new DataGov client with custom configuration
39    pub fn with_config(config: DataGovConfig) -> Result<Self> {
40        let ckan = CkanClient::new(config.ckan_config.clone());
41
42        // Create HTTP client with timeout for downloads
43        let http_client = reqwest::Client::builder()
44            .timeout(std::time::Duration::from_secs(config.download_timeout_secs))
45            .user_agent(&config.user_agent)
46            .build()?;
47
48        Ok(Self {
49            ckan,
50            config,
51            http_client,
52        })
53    }
54
55    // === Search and Discovery ===
56
57    /// Search for datasets on data.gov.
58    ///
59    /// # Arguments
60    /// * `query` - Search terms (searches titles, descriptions, tags)
61    ///
62    /// Solr syntax: The `query` string may include Solr-style expressions such as
63    /// wildcards (e.g. `climat*`), phrase searches (`"air quality"`), and boolean
64    /// operators (`AND`, `OR`, `NOT`). When combined with structured filters the
65    /// underlying CKAN `package_search` endpoint interprets `q` and `fq` using
66    /// Solr semantics, so complex queries are supported.
67    /// * `limit` - Maximum number of results (default: 10, max: 1000)
68    /// * `offset` - Number of results to skip for pagination (default: 0)
69    /// * `organization` - Filter by organization name (optional)
70    /// * `format` - Filter by resource format (optional, e.g., "CSV", "JSON")
71    ///
72    /// # Examples
73    ///
74    /// Basic search:
75    /// ```rust,no_run
76    /// # use data_gov::DataGovClient;
77    /// # #[tokio::main]
78    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
79    /// let client = DataGovClient::new()?;
80    /// let results = client.search("climate data", Some(20), None, None, None).await?;
81    /// # Ok(())
82    /// # }
83    /// ```
84    ///
85    /// Search with filters:
86    /// ```rust,no_run
87    /// # use data_gov::DataGovClient;
88    /// # #[tokio::main]
89    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
90    /// # let client = DataGovClient::new()?;
91    /// let results = client.search("energy", Some(10), None, Some("doe-gov"), Some("CSV")).await?;
92    /// # Ok(())
93    /// # }
94    /// ```
95    pub async fn search(
96        &self,
97        query: &str,
98        limit: Option<i32>,
99        offset: Option<i32>,
100        organization: Option<&str>,
101        format: Option<&str>,
102    ) -> Result<PackageSearchResult> {
103        // Build filter query for advanced filtering
104        // Use quotes around values to ensure proper Solr parsing
105        let fq = match (organization, format) {
106            (Some(org), Some(fmt)) => Some(format!(
107                r#"organization:"{}" AND res_format:"{}""#,
108                org, fmt
109            )),
110            (Some(org), None) => Some(format!(r#"organization:"{}""#, org)),
111            (None, Some(fmt)) => Some(format!(r#"res_format:"{}""#, fmt)),
112            (None, None) => None,
113        };
114
115        // Only pass query if it's non-empty (don't use "*" as it interferes with filters)
116        let query_param = if query.is_empty() { None } else { Some(query) };
117
118        let result = self
119            .ckan
120            .package_search(query_param, limit, offset, fq.as_deref())
121            .await?;
122
123        Ok(result)
124    }
125
126    /// Fetch the full `package_show` payload for a dataset.
127    pub async fn get_dataset(&self, dataset_id: &str) -> Result<Package> {
128        let package = self.ckan.package_show(dataset_id).await?;
129        Ok(package)
130    }
131
132    /// Fetch dataset name suggestions for interactive prompts.
133    pub async fn autocomplete_datasets(
134        &self,
135        partial: &str,
136        limit: Option<i32>,
137    ) -> Result<Vec<String>> {
138        let suggestions = self.ckan.dataset_autocomplete(Some(partial), limit).await?;
139        Ok(suggestions.into_iter().filter_map(|s| s.name).collect())
140    }
141
142    /// List the publisher slugs for government organizations.
143    pub async fn list_organizations(&self, limit: Option<i32>) -> Result<Vec<String>> {
144        let orgs = self.ckan.organization_list(None, limit, None).await?;
145        Ok(orgs)
146    }
147
148    /// Fetch organization name suggestions for interactive prompts.
149    pub async fn autocomplete_organizations(
150        &self,
151        partial: &str,
152        limit: Option<i32>,
153    ) -> Result<Vec<String>> {
154        let suggestions = self
155            .ckan
156            .organization_autocomplete(Some(partial), limit)
157            .await?;
158        Ok(suggestions.into_iter().filter_map(|s| s.name).collect())
159    }
160
161    // === Resource Management ===
162
163    /// Return resources that look like downloadable files.
164    ///
165    /// The returned list is filtered to resources that expose a direct URL, are
166    /// not marked as API endpoints, and advertise a file format.
167    pub fn get_downloadable_resources(package: &Package) -> Vec<Resource> {
168        package
169            .resources
170            .as_ref()
171            .unwrap_or(&Vec::new())
172            .iter()
173            .filter(|resource| {
174                // Has a URL and is not an API endpoint
175                resource.url.is_some()
176                    && resource.url_type.as_deref() != Some("api")
177                    && resource.format.is_some()
178            })
179            .cloned()
180            .collect()
181    }
182
183    /// Pick a filesystem-friendly filename for a resource download.
184    /// Pick a filesystem-friendly filename for a resource download.
185    ///
186    /// # Arguments
187    /// * `resource` - The resource to generate a filename for
188    /// * `fallback_name` - Optional fallback name if resource has no name
189    /// * `resource_index` - Optional index to append to prevent conflicts when multiple resources have the same name
190    ///
191    /// When downloading multiple resources, the index should be provided to ensure unique filenames
192    /// even when resources have duplicate names.
193    pub fn get_resource_filename(
194        resource: &Resource,
195        fallback_name: Option<&str>,
196        resource_index: Option<usize>,
197    ) -> String {
198        // Determine base filename and whether it has an extension
199        let (base_filename, has_extension) = if let Some(name) = &resource.name {
200            if let Some(format) = &resource.format {
201                let format_lower = format.to_lowercase();
202                if name.ends_with(&format!(".{}", format_lower)) {
203                    (name.clone(), true)
204                } else {
205                    (format!("{}.{}", name, format_lower), true)
206                }
207            } else {
208                (name.clone(), false)
209            }
210        } else if let Some(url) = &resource.url
211            && let Ok(parsed_url) = Url::parse(url)
212            && let Some(mut segments) = parsed_url.path_segments()
213            && let Some(filename) = segments.next_back()
214            && !filename.is_empty()
215            && filename.contains('.')
216        {
217            (filename.to_string(), true)
218        } else {
219            // Use fallback with format extension
220            let base_name = fallback_name.unwrap_or("data");
221            if let Some(format) = &resource.format {
222                (format!("{}.{}", base_name, format.to_lowercase()), true)
223            } else {
224                (format!("{}.dat", base_name), true)
225            }
226        };
227
228        // Append resource index if provided (to handle duplicate names)
229        if let Some(index) = resource_index {
230            if has_extension {
231                // Insert index before the extension
232                if let Some(dot_pos) = base_filename.rfind('.') {
233                    let (name, ext) = base_filename.split_at(dot_pos);
234                    return format!("{}-{}{}", name, index, ext);
235                }
236            }
237            // No extension or couldn't find dot, just append
238            format!("{}-{}", base_filename, index)
239        } else {
240            base_filename
241        }
242    }
243
244    // === File Downloads ===
245
246    /// Download a single resource to the specified directory.
247    ///
248    /// # Arguments
249    /// * `resource` - The resource to download
250    /// * `output_dir` - Directory where the file will be saved. If None, uses the base download directory.
251    ///
252    /// Returns the full path where the file was saved.
253    pub async fn download_resource(
254        &self,
255        resource: &Resource,
256        output_dir: Option<&Path>,
257    ) -> Result<PathBuf> {
258        let url = match resource.url.as_deref() {
259            Some(url) => url,
260            None => {
261                if let Some(reporter) = self.config.status_reporter.as_ref() {
262                    let event = DownloadFailed {
263                        resource_name: resource.name.clone(),
264                        dataset_name: None,
265                        output_path: None,
266                        error: "Resource has no URL".to_string(),
267                    };
268                    reporter.on_download_failed(&event);
269                }
270                return Err(DataGovError::resource_not_found("Resource has no URL"));
271            }
272        };
273
274        let output_dir = output_dir
275            .map(|p| p.to_path_buf())
276            .unwrap_or_else(|| self.config.get_base_download_dir());
277        // No resource index needed for single download
278        let filename = Self::get_resource_filename(resource, None, None);
279        let output_path = output_dir.join(filename);
280
281        Self::perform_download(
282            &self.http_client,
283            url,
284            &output_path,
285            resource.name.clone(),
286            None,
287            self.reporter(),
288        )
289        .await?;
290
291        Ok(output_path)
292    }
293
294    /// Download multiple resources concurrently to the specified directory.
295    ///
296    /// # Arguments
297    /// * `resources` - Slice of resources to download
298    /// * `output_dir` - Directory where files will be saved. If None, uses the base download directory.
299    ///
300    /// Returns one [`Result`] per resource so callers can inspect partial failures.
301    pub async fn download_resources(
302        &self,
303        resources: &[Resource],
304        output_dir: Option<&Path>,
305    ) -> Vec<Result<PathBuf>> {
306        if resources.is_empty() {
307            return vec![];
308        }
309
310        if resources.len() == 1 {
311            return vec![self.download_resource(&resources[0], output_dir).await];
312        }
313
314        // Multiple resources - use concurrent downloads with semaphore
315        if let Some(reporter) = self.config.status_reporter.as_ref() {
316            let event = DownloadBatch {
317                resource_count: resources.len(),
318                dataset_name: None,
319            };
320            reporter.on_download_batch(&event);
321        }
322
323        let output_dir = output_dir
324            .map(|p| p.to_path_buf())
325            .unwrap_or_else(|| self.config.get_base_download_dir());
326
327        let semaphore = Arc::new(tokio::sync::Semaphore::new(
328            self.config.max_concurrent_downloads,
329        ));
330
331        let status_reporter = self.reporter();
332        let mut futures = Vec::with_capacity(resources.len());
333
334        for (resource_index, resource) in resources.iter().enumerate() {
335            let resource = resource.clone();
336            let output_dir = output_dir.clone();
337            let semaphore = semaphore.clone();
338            let http_client = self.http_client.clone();
339            let status_reporter = status_reporter.clone();
340
341            let future = async move {
342                let _permit = match semaphore.acquire().await {
343                    Ok(permit) => permit,
344                    Err(e) => {
345                        if let Some(reporter) = status_reporter.as_ref() {
346                            let event = DownloadFailed {
347                                resource_name: resource.name.clone(),
348                                dataset_name: None,
349                                output_path: None,
350                                error: format!("Failed to acquire download slot: {}", e),
351                            };
352                            reporter.on_download_failed(&event);
353                        }
354                        return Err(DataGovError::download_error(format!(
355                            "Semaphore error: {}",
356                            e
357                        )));
358                    }
359                };
360
361                let url = match resource.url.as_deref() {
362                    Some(url) => url,
363                    None => {
364                        if let Some(reporter) = status_reporter.as_ref() {
365                            let event = DownloadFailed {
366                                resource_name: resource.name.clone(),
367                                dataset_name: None,
368                                output_path: None,
369                                error: "Resource has no URL".to_string(),
370                            };
371                            reporter.on_download_failed(&event);
372                        }
373                        return Err(DataGovError::resource_not_found("Resource has no URL"));
374                    }
375                };
376
377                // Include resource index to prevent filename conflicts
378                let filename =
379                    DataGovClient::get_resource_filename(&resource, None, Some(resource_index));
380                let output_path = output_dir.join(&filename);
381
382                DataGovClient::perform_download(
383                    &http_client,
384                    url,
385                    &output_path,
386                    resource.name.clone(),
387                    None,
388                    status_reporter,
389                )
390                .await?;
391
392                Ok(output_path)
393            };
394
395            futures.push(future);
396        }
397
398        futures::future::join_all(futures).await
399    }
400
401    fn reporter(&self) -> Option<Arc<dyn StatusReporter + Send + Sync>> {
402        self.config.status_reporter.clone()
403    }
404
405    async fn perform_download(
406        http_client: &reqwest::Client,
407        url: &str,
408        output_path: &Path,
409        resource_name: Option<String>,
410        dataset_name: Option<String>,
411        status_reporter: Option<Arc<dyn StatusReporter + Send + Sync>>,
412    ) -> Result<()> {
413        let notify_failure =
414            |message: String, status_reporter: &Option<Arc<dyn StatusReporter + Send + Sync>>| {
415                if let Some(reporter) = status_reporter.as_ref() {
416                    let event = DownloadFailed {
417                        resource_name: resource_name.clone(),
418                        dataset_name: dataset_name.clone(),
419                        output_path: Some(output_path.to_path_buf()),
420                        error: message.clone(),
421                    };
422                    reporter.on_download_failed(&event);
423                }
424            };
425
426        if let Some(parent) = output_path.parent()
427            && let Err(err) = tokio::fs::create_dir_all(parent).await
428        {
429            notify_failure(err.to_string(), &status_reporter);
430            return Err(err.into());
431        }
432
433        let response = match http_client.get(url).send().await {
434            Ok(resp) => resp,
435            Err(err) => {
436                notify_failure(err.to_string(), &status_reporter);
437                return Err(err.into());
438            }
439        };
440
441        if !response.status().is_success() {
442            let message = format!("HTTP {} while downloading {}", response.status(), url);
443            notify_failure(message.clone(), &status_reporter);
444            return Err(DataGovError::download_error(message));
445        }
446
447        let total_size = response.content_length();
448
449        if let Some(reporter) = status_reporter.as_ref() {
450            let event = DownloadStarted {
451                resource_name: resource_name.clone(),
452                dataset_name: dataset_name.clone(),
453                url: url.to_string(),
454                output_path: output_path.to_path_buf(),
455                total_bytes: total_size,
456            };
457            reporter.on_download_started(&event);
458        }
459
460        let mut file = match File::create(output_path).await {
461            Ok(file) => file,
462            Err(err) => {
463                notify_failure(err.to_string(), &status_reporter);
464                return Err(err.into());
465            }
466        };
467
468        let mut stream = response.bytes_stream();
469        let mut downloaded = 0u64;
470
471        while let Some(chunk_result) = stream.next().await {
472            let chunk = match chunk_result {
473                Ok(chunk) => chunk,
474                Err(err) => {
475                    notify_failure(err.to_string(), &status_reporter);
476                    return Err(err.into());
477                }
478            };
479
480            if let Err(err) = file.write_all(&chunk).await {
481                notify_failure(err.to_string(), &status_reporter);
482                return Err(err.into());
483            }
484
485            downloaded += chunk.len() as u64;
486
487            if let Some(reporter) = status_reporter.as_ref() {
488                let event = DownloadProgress {
489                    resource_name: resource_name.clone(),
490                    dataset_name: dataset_name.clone(),
491                    output_path: output_path.to_path_buf(),
492                    downloaded_bytes: downloaded,
493                    total_bytes: total_size,
494                };
495                reporter.on_download_progress(&event);
496            }
497        }
498
499        if let Some(reporter) = status_reporter.as_ref() {
500            let event = DownloadFinished {
501                resource_name,
502                dataset_name,
503                output_path: output_path.to_path_buf(),
504            };
505            reporter.on_download_finished(&event);
506        }
507
508        Ok(())
509    }
510
511    /// Check if the base download directory exists and is writable
512    pub async fn validate_download_dir(&self) -> Result<()> {
513        let base_dir = self.config.get_base_download_dir();
514
515        if !base_dir.exists() {
516            tokio::fs::create_dir_all(&base_dir).await?;
517        }
518
519        if !base_dir.is_dir() {
520            return Err(DataGovError::config_error(format!(
521                "Download path is not a directory: {:?}",
522                base_dir
523            )));
524        }
525
526        let test_file = base_dir.join(".write_test");
527        tokio::fs::write(&test_file, b"test").await?;
528        tokio::fs::remove_file(&test_file).await?;
529
530        Ok(())
531    }
532
533    /// Get the current base download directory
534    pub fn download_dir(&self) -> PathBuf {
535        self.config.get_base_download_dir()
536    }
537
538    /// Get the underlying CKAN client for advanced operations
539    pub fn ckan_client(&self) -> &CkanClient {
540        &self.ckan
541    }
542}
543
544impl Default for DataGovClient {
545    fn default() -> Self {
546        Self::new().expect("Failed to create default DataGovClient")
547    }
548}
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a test resource with the given name, optional format, and URL;
    /// every other field keeps its default value.
    fn resource_fixture(name: &str, format: Option<&str>, url: &str) -> Resource {
        Resource {
            name: Some(name.to_string()),
            format: format.map(String::from),
            url: Some(url.to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_get_resource_filename_no_index() {
        let csv = resource_fixture("data", Some("CSV"), "https://example.com/data.csv");

        assert_eq!(
            DataGovClient::get_resource_filename(&csv, None, None),
            "data.csv"
        );
    }

    #[test]
    fn test_get_resource_filename_with_index() {
        let csv = resource_fixture("data", Some("CSV"), "https://example.com/data.csv");

        // The index is placed between the stem and the extension.
        for index in 0..3usize {
            let filename = DataGovClient::get_resource_filename(&csv, None, Some(index));
            assert_eq!(filename, format!("data-{}.csv", index));
        }
    }

    #[test]
    fn test_get_resource_filename_already_has_extension() {
        let report = resource_fixture(
            "report.csv",
            Some("CSV"),
            "https://example.com/report.csv",
        );

        // The extension is already present, so it must not be appended twice.
        assert_eq!(
            DataGovClient::get_resource_filename(&report, None, Some(3)),
            "report-3.csv"
        );
    }

    #[test]
    fn test_get_resource_filename_no_format() {
        let plain = resource_fixture("myfile", None, "https://example.com/myfile");

        // Without a format there is no extension, so the index goes at the end.
        assert_eq!(
            DataGovClient::get_resource_filename(&plain, None, Some(5)),
            "myfile-5"
        );
    }

    #[test]
    fn test_get_resource_filename_multiple_extensions() {
        let archive = resource_fixture(
            "archive.tar.gz",
            Some("TAR"),
            "https://example.com/archive.tar.gz",
        );

        // Name doesn't end with .tar (it ends with .gz), so format is appended -> archive.tar.gz.tar
        // Then index is inserted before last dot -> archive.tar.gz-7.tar
        assert_eq!(
            DataGovClient::get_resource_filename(&archive, None, Some(7)),
            "archive.tar.gz-7.tar"
        );
    }
}