// data_gov/client.rs

use std::path::{Path, PathBuf};
use std::sync::Arc;

use data_gov_ckan::{
    CkanClient,
    models::{Package, PackageSearchResult, Resource},
};
use futures::StreamExt;
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use url::Url;

use crate::config::DataGovConfig;
use crate::error::{DataGovError, Result};
use crate::ui::{
    DownloadBatch, DownloadFailed, DownloadFinished, DownloadProgress, DownloadStarted,
    StatusReporter,
};

19/// Async client for exploring data.gov datasets.
20///
21/// `DataGovClient` layers ergonomic helpers on top of the lower-level
22/// [`data_gov_ckan::CkanClient`]. In addition to search and metadata lookups it
23/// handles download destinations, progress reporting, and colour-aware output
24/// that matches the `data-gov` CLI defaults.
25#[derive(Debug)]
26pub struct DataGovClient {
27    ckan: CkanClient,
28    config: DataGovConfig,
29    http_client: reqwest::Client,
30}
31
32impl DataGovClient {
33    /// Create a new DataGov client with default configuration
34    pub fn new() -> Result<Self> {
35        Self::with_config(DataGovConfig::new())
36    }
37
38    /// Create a new DataGov client with custom configuration
39    pub fn with_config(config: DataGovConfig) -> Result<Self> {
40        let ckan = CkanClient::new(config.ckan_config.clone());
41
42        // Create HTTP client with timeout for downloads
43        let http_client = reqwest::Client::builder()
44            .timeout(std::time::Duration::from_secs(config.download_timeout_secs))
45            .user_agent(&config.user_agent)
46            .build()?;
47
48        Ok(Self {
49            ckan,
50            config,
51            http_client,
52        })
53    }
54
55    // === Search and Discovery ===
56
57    /// Search for datasets on data.gov.
58    ///
59    /// # Arguments
60    /// * `query` - Search terms (searches titles, descriptions, tags)
61    /// * `limit` - Maximum number of results (default: 10, max: 1000)
62    /// * `offset` - Number of results to skip for pagination (default: 0)
63    /// * `organization` - Filter by organization name (optional)
64    /// * `format` - Filter by resource format (optional, e.g., "CSV", "JSON")
65    ///
66    /// # Examples
67    ///
68    /// Basic search:
69    /// ```rust,no_run
70    /// # use data_gov::DataGovClient;
71    /// # #[tokio::main]
72    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
73    /// let client = DataGovClient::new()?;
74    /// let results = client.search("climate data", Some(20), None, None, None).await?;
75    /// # Ok(())
76    /// # }
77    /// ```
78    ///
79    /// Search with filters:
80    /// ```rust,no_run
81    /// # use data_gov::DataGovClient;
82    /// # #[tokio::main]
83    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
84    /// # let client = DataGovClient::new()?;
85    /// let results = client.search("energy", Some(10), None, Some("doe-gov"), Some("CSV")).await?;
86    /// # Ok(())
87    /// # }
88    /// ```
89    pub async fn search(
90        &self,
91        query: &str,
92        limit: Option<i32>,
93        offset: Option<i32>,
94        organization: Option<&str>,
95        format: Option<&str>,
96    ) -> Result<PackageSearchResult> {
97        // Build filter query for advanced filtering
98        let fq = match (organization, format) {
99            (Some(org), Some(fmt)) => Some(format!(
100                r#"organization:"{}" AND res_format:"{}""#,
101                org, fmt
102            )),
103            (Some(org), None) => Some(format!(r#"organization:"{}""#, org)),
104            (None, Some(fmt)) => Some(format!(r#"res_format:"{}""#, fmt)),
105            (None, None) => None,
106        };
107
108        let result = self
109            .ckan
110            .package_search(Some(query), limit, offset, fq.as_deref())
111            .await?;
112
113        Ok(result)
114    }
115
116    /// Fetch the full `package_show` payload for a dataset.
117    pub async fn get_dataset(&self, dataset_id: &str) -> Result<Package> {
118        let package = self.ckan.package_show(dataset_id).await?;
119        Ok(package)
120    }
121
122    /// Fetch dataset name suggestions for interactive prompts.
123    pub async fn autocomplete_datasets(
124        &self,
125        partial: &str,
126        limit: Option<i32>,
127    ) -> Result<Vec<String>> {
128        let suggestions = self.ckan.dataset_autocomplete(Some(partial), limit).await?;
129        Ok(suggestions.into_iter().filter_map(|s| s.name).collect())
130    }
131
132    /// List the publisher slugs for government organizations.
133    pub async fn list_organizations(&self, limit: Option<i32>) -> Result<Vec<String>> {
134        let orgs = self.ckan.organization_list(None, limit, None).await?;
135        Ok(orgs)
136    }
137
138    /// Fetch organization name suggestions for interactive prompts.
139    pub async fn autocomplete_organizations(
140        &self,
141        partial: &str,
142        limit: Option<i32>,
143    ) -> Result<Vec<String>> {
144        let suggestions = self
145            .ckan
146            .organization_autocomplete(Some(partial), limit)
147            .await?;
148        Ok(suggestions.into_iter().filter_map(|s| s.name).collect())
149    }
150
151    // === Resource Management ===
152
153    /// Return resources that look like downloadable files.
154    ///
155    /// The returned list is filtered to resources that expose a direct URL, are
156    /// not marked as API endpoints, and advertise a file format.
157    pub fn get_downloadable_resources(package: &Package) -> Vec<Resource> {
158        package
159            .resources
160            .as_ref()
161            .unwrap_or(&Vec::new())
162            .iter()
163            .filter(|resource| {
164                // Has a URL and is not an API endpoint
165                resource.url.is_some()
166                    && resource.url_type.as_deref() != Some("api")
167                    && resource.format.is_some()
168            })
169            .cloned()
170            .collect()
171    }
172
173    /// Pick a filesystem-friendly filename for a resource download.
174    /// Pick a filesystem-friendly filename for a resource download.
175    ///
176    /// # Arguments
177    /// * `resource` - The resource to generate a filename for
178    /// * `fallback_name` - Optional fallback name if resource has no name
179    /// * `resource_index` - Optional index to append to prevent conflicts when multiple resources have the same name
180    ///
181    /// When downloading multiple resources, the index should be provided to ensure unique filenames
182    /// even when resources have duplicate names.
183    pub fn get_resource_filename(
184        resource: &Resource,
185        fallback_name: Option<&str>,
186        resource_index: Option<usize>,
187    ) -> String {
188        // Determine base filename and whether it has an extension
189        let (base_filename, has_extension) = if let Some(name) = &resource.name {
190            if let Some(format) = &resource.format {
191                let format_lower = format.to_lowercase();
192                if name.ends_with(&format!(".{}", format_lower)) {
193                    (name.clone(), true)
194                } else {
195                    (format!("{}.{}", name, format_lower), true)
196                }
197            } else {
198                (name.clone(), false)
199            }
200        } else if let Some(url) = &resource.url
201            && let Ok(parsed_url) = Url::parse(url)
202            && let Some(mut segments) = parsed_url.path_segments()
203            && let Some(filename) = segments.next_back()
204            && !filename.is_empty()
205            && filename.contains('.')
206        {
207            (filename.to_string(), true)
208        } else {
209            // Use fallback with format extension
210            let base_name = fallback_name.unwrap_or("data");
211            if let Some(format) = &resource.format {
212                (format!("{}.{}", base_name, format.to_lowercase()), true)
213            } else {
214                (format!("{}.dat", base_name), true)
215            }
216        };
217
218        // Append resource index if provided (to handle duplicate names)
219        if let Some(index) = resource_index {
220            if has_extension {
221                // Insert index before the extension
222                if let Some(dot_pos) = base_filename.rfind('.') {
223                    let (name, ext) = base_filename.split_at(dot_pos);
224                    return format!("{}-{}{}", name, index, ext);
225                }
226            }
227            // No extension or couldn't find dot, just append
228            format!("{}-{}", base_filename, index)
229        } else {
230            base_filename
231        }
232    }
233
234    // === File Downloads ===
235
236    /// Download a single resource to the specified directory.
237    ///
238    /// # Arguments
239    /// * `resource` - The resource to download
240    /// * `output_dir` - Directory where the file will be saved. If None, uses the base download directory.
241    ///
242    /// Returns the full path where the file was saved.
243    pub async fn download_resource(
244        &self,
245        resource: &Resource,
246        output_dir: Option<&Path>,
247    ) -> Result<PathBuf> {
248        let url = match resource.url.as_deref() {
249            Some(url) => url,
250            None => {
251                if let Some(reporter) = self.config.status_reporter.as_ref() {
252                    let event = DownloadFailed {
253                        resource_name: resource.name.clone(),
254                        dataset_name: None,
255                        output_path: None,
256                        error: "Resource has no URL".to_string(),
257                    };
258                    reporter.on_download_failed(&event);
259                }
260                return Err(DataGovError::resource_not_found("Resource has no URL"));
261            }
262        };
263
264        let output_dir = output_dir
265            .map(|p| p.to_path_buf())
266            .unwrap_or_else(|| self.config.get_base_download_dir());
267        // No resource index needed for single download
268        let filename = Self::get_resource_filename(resource, None, None);
269        let output_path = output_dir.join(filename);
270
271        Self::perform_download(
272            &self.http_client,
273            url,
274            &output_path,
275            resource.name.clone(),
276            None,
277            self.reporter(),
278        )
279        .await?;
280
281        Ok(output_path)
282    }
283
284    /// Download multiple resources concurrently to the specified directory.
285    ///
286    /// # Arguments
287    /// * `resources` - Slice of resources to download
288    /// * `output_dir` - Directory where files will be saved. If None, uses the base download directory.
289    ///
290    /// Returns one [`Result`] per resource so callers can inspect partial failures.
291    pub async fn download_resources(
292        &self,
293        resources: &[Resource],
294        output_dir: Option<&Path>,
295    ) -> Vec<Result<PathBuf>> {
296        if resources.is_empty() {
297            return vec![];
298        }
299
300        if resources.len() == 1 {
301            return vec![self.download_resource(&resources[0], output_dir).await];
302        }
303
304        // Multiple resources - use concurrent downloads with semaphore
305        if let Some(reporter) = self.config.status_reporter.as_ref() {
306            let event = DownloadBatch {
307                resource_count: resources.len(),
308                dataset_name: None,
309            };
310            reporter.on_download_batch(&event);
311        }
312
313        let output_dir = output_dir
314            .map(|p| p.to_path_buf())
315            .unwrap_or_else(|| self.config.get_base_download_dir());
316
317        let semaphore = Arc::new(tokio::sync::Semaphore::new(
318            self.config.max_concurrent_downloads,
319        ));
320
321        let status_reporter = self.reporter();
322        let mut futures = Vec::with_capacity(resources.len());
323
324        for (resource_index, resource) in resources.iter().enumerate() {
325            let resource = resource.clone();
326            let output_dir = output_dir.clone();
327            let semaphore = semaphore.clone();
328            let http_client = self.http_client.clone();
329            let status_reporter = status_reporter.clone();
330
331            let future = async move {
332                let _permit = match semaphore.acquire().await {
333                    Ok(permit) => permit,
334                    Err(e) => {
335                        if let Some(reporter) = status_reporter.as_ref() {
336                            let event = DownloadFailed {
337                                resource_name: resource.name.clone(),
338                                dataset_name: None,
339                                output_path: None,
340                                error: format!("Failed to acquire download slot: {}", e),
341                            };
342                            reporter.on_download_failed(&event);
343                        }
344                        return Err(DataGovError::download_error(format!(
345                            "Semaphore error: {}",
346                            e
347                        )));
348                    }
349                };
350
351                let url = match resource.url.as_deref() {
352                    Some(url) => url,
353                    None => {
354                        if let Some(reporter) = status_reporter.as_ref() {
355                            let event = DownloadFailed {
356                                resource_name: resource.name.clone(),
357                                dataset_name: None,
358                                output_path: None,
359                                error: "Resource has no URL".to_string(),
360                            };
361                            reporter.on_download_failed(&event);
362                        }
363                        return Err(DataGovError::resource_not_found("Resource has no URL"));
364                    }
365                };
366
367                // Include resource index to prevent filename conflicts
368                let filename =
369                    DataGovClient::get_resource_filename(&resource, None, Some(resource_index));
370                let output_path = output_dir.join(&filename);
371
372                DataGovClient::perform_download(
373                    &http_client,
374                    url,
375                    &output_path,
376                    resource.name.clone(),
377                    None,
378                    status_reporter,
379                )
380                .await?;
381
382                Ok(output_path)
383            };
384
385            futures.push(future);
386        }
387
388        futures::future::join_all(futures).await
389    }
390
391    fn reporter(&self) -> Option<Arc<dyn StatusReporter + Send + Sync>> {
392        self.config.status_reporter.clone()
393    }
394
395    async fn perform_download(
396        http_client: &reqwest::Client,
397        url: &str,
398        output_path: &Path,
399        resource_name: Option<String>,
400        dataset_name: Option<String>,
401        status_reporter: Option<Arc<dyn StatusReporter + Send + Sync>>,
402    ) -> Result<()> {
403        let notify_failure =
404            |message: String, status_reporter: &Option<Arc<dyn StatusReporter + Send + Sync>>| {
405                if let Some(reporter) = status_reporter.as_ref() {
406                    let event = DownloadFailed {
407                        resource_name: resource_name.clone(),
408                        dataset_name: dataset_name.clone(),
409                        output_path: Some(output_path.to_path_buf()),
410                        error: message.clone(),
411                    };
412                    reporter.on_download_failed(&event);
413                }
414            };
415
416        if let Some(parent) = output_path.parent()
417            && let Err(err) = tokio::fs::create_dir_all(parent).await
418        {
419            notify_failure(err.to_string(), &status_reporter);
420            return Err(err.into());
421        }
422
423        let response = match http_client.get(url).send().await {
424            Ok(resp) => resp,
425            Err(err) => {
426                notify_failure(err.to_string(), &status_reporter);
427                return Err(err.into());
428            }
429        };
430
431        if !response.status().is_success() {
432            let message = format!("HTTP {} while downloading {}", response.status(), url);
433            notify_failure(message.clone(), &status_reporter);
434            return Err(DataGovError::download_error(message));
435        }
436
437        let total_size = response.content_length();
438
439        if let Some(reporter) = status_reporter.as_ref() {
440            let event = DownloadStarted {
441                resource_name: resource_name.clone(),
442                dataset_name: dataset_name.clone(),
443                url: url.to_string(),
444                output_path: output_path.to_path_buf(),
445                total_bytes: total_size,
446            };
447            reporter.on_download_started(&event);
448        }
449
450        let mut file = match File::create(output_path).await {
451            Ok(file) => file,
452            Err(err) => {
453                notify_failure(err.to_string(), &status_reporter);
454                return Err(err.into());
455            }
456        };
457
458        let mut stream = response.bytes_stream();
459        let mut downloaded = 0u64;
460
461        while let Some(chunk_result) = stream.next().await {
462            let chunk = match chunk_result {
463                Ok(chunk) => chunk,
464                Err(err) => {
465                    notify_failure(err.to_string(), &status_reporter);
466                    return Err(err.into());
467                }
468            };
469
470            if let Err(err) = file.write_all(&chunk).await {
471                notify_failure(err.to_string(), &status_reporter);
472                return Err(err.into());
473            }
474
475            downloaded += chunk.len() as u64;
476
477            if let Some(reporter) = status_reporter.as_ref() {
478                let event = DownloadProgress {
479                    resource_name: resource_name.clone(),
480                    dataset_name: dataset_name.clone(),
481                    output_path: output_path.to_path_buf(),
482                    downloaded_bytes: downloaded,
483                    total_bytes: total_size,
484                };
485                reporter.on_download_progress(&event);
486            }
487        }
488
489        if let Some(reporter) = status_reporter.as_ref() {
490            let event = DownloadFinished {
491                resource_name,
492                dataset_name,
493                output_path: output_path.to_path_buf(),
494            };
495            reporter.on_download_finished(&event);
496        }
497
498        Ok(())
499    }
500
501    /// Check if the base download directory exists and is writable
502    pub async fn validate_download_dir(&self) -> Result<()> {
503        let base_dir = self.config.get_base_download_dir();
504
505        if !base_dir.exists() {
506            tokio::fs::create_dir_all(&base_dir).await?;
507        }
508
509        if !base_dir.is_dir() {
510            return Err(DataGovError::config_error(format!(
511                "Download path is not a directory: {:?}",
512                base_dir
513            )));
514        }
515
516        let test_file = base_dir.join(".write_test");
517        tokio::fs::write(&test_file, b"test").await?;
518        tokio::fs::remove_file(&test_file).await?;
519
520        Ok(())
521    }
522
523    /// Get the current base download directory
524    pub fn download_dir(&self) -> PathBuf {
525        self.config.get_base_download_dir()
526    }
527
528    /// Get the underlying CKAN client for advanced operations
529    pub fn ckan_client(&self) -> &CkanClient {
530        &self.ckan
531    }
532}
533
534impl Default for DataGovClient {
535    fn default() -> Self {
536        Self::new().expect("Failed to create default DataGovClient")
537    }
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a resource with the given name, optional format and URL.
    fn resource(name: &str, format: Option<&str>, url: &str) -> Resource {
        Resource {
            name: Some(name.to_string()),
            format: format.map(str::to_string),
            url: Some(url.to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_get_resource_filename_no_index() {
        let csv = resource("data", Some("CSV"), "https://example.com/data.csv");
        assert_eq!(
            DataGovClient::get_resource_filename(&csv, None, None),
            "data.csv"
        );
    }

    #[test]
    fn test_get_resource_filename_with_index() {
        let csv = resource("data", Some("CSV"), "https://example.com/data.csv");

        for (index, expected) in [(0, "data-0.csv"), (1, "data-1.csv"), (2, "data-2.csv")] {
            assert_eq!(
                DataGovClient::get_resource_filename(&csv, None, Some(index)),
                expected
            );
        }
    }

    #[test]
    fn test_get_resource_filename_already_has_extension() {
        let report = resource("report.csv", Some("CSV"), "https://example.com/report.csv");

        assert_eq!(
            DataGovClient::get_resource_filename(&report, None, Some(3)),
            "report-3.csv"
        );
    }

    #[test]
    fn test_get_resource_filename_no_format() {
        let plain = resource("myfile", None, "https://example.com/myfile");

        assert_eq!(
            DataGovClient::get_resource_filename(&plain, None, Some(5)),
            "myfile-5"
        );
    }

    #[test]
    fn test_get_resource_filename_multiple_extensions() {
        let archive = resource(
            "archive.tar.gz",
            Some("TAR"),
            "https://example.com/archive.tar.gz",
        );

        // Name doesn't end with .tar (it ends with .gz), so format is appended -> archive.tar.gz.tar
        // Then index is inserted before last dot -> archive.tar.gz-7.tar
        assert_eq!(
            DataGovClient::get_resource_filename(&archive, None, Some(7)),
            "archive.tar.gz-7.tar"
        );
    }
}