1use futures::StreamExt;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use tokio::fs::File;
5use tokio::io::AsyncWriteExt;
6use url::Url;
7
8use crate::config::DataGovConfig;
9use crate::error::{DataGovError, Result};
10use crate::ui::{
11 DownloadBatch, DownloadFailed, DownloadFinished, DownloadProgress, DownloadStarted,
12 StatusReporter,
13};
14use data_gov_ckan::{
15 CkanClient,
16 models::{Package, PackageSearchResult, Resource},
17};
18
19#[derive(Debug)]
26pub struct DataGovClient {
27 ckan: CkanClient,
28 config: DataGovConfig,
29 http_client: reqwest::Client,
30}
31
32impl DataGovClient {
33 pub fn new() -> Result<Self> {
35 Self::with_config(DataGovConfig::new())
36 }
37
38 pub fn with_config(config: DataGovConfig) -> Result<Self> {
40 let ckan = CkanClient::new(config.ckan_config.clone());
41
42 let http_client = reqwest::Client::builder()
44 .timeout(std::time::Duration::from_secs(config.download_timeout_secs))
45 .user_agent(&config.user_agent)
46 .build()?;
47
48 Ok(Self {
49 ckan,
50 config,
51 http_client,
52 })
53 }
54
55 pub async fn search(
96 &self,
97 query: &str,
98 limit: Option<i32>,
99 offset: Option<i32>,
100 organization: Option<&str>,
101 format: Option<&str>,
102 ) -> Result<PackageSearchResult> {
103 let fq = match (organization, format) {
106 (Some(org), Some(fmt)) => Some(format!(
107 r#"organization:"{}" AND res_format:"{}""#,
108 org, fmt
109 )),
110 (Some(org), None) => Some(format!(r#"organization:"{}""#, org)),
111 (None, Some(fmt)) => Some(format!(r#"res_format:"{}""#, fmt)),
112 (None, None) => None,
113 };
114
115 let query_param = if query.is_empty() { None } else { Some(query) };
117
118 let result = self
119 .ckan
120 .package_search(query_param, limit, offset, fq.as_deref())
121 .await?;
122
123 Ok(result)
124 }
125
126 pub async fn get_dataset(&self, dataset_id: &str) -> Result<Package> {
128 let package = self.ckan.package_show(dataset_id).await?;
129 Ok(package)
130 }
131
132 pub async fn autocomplete_datasets(
134 &self,
135 partial: &str,
136 limit: Option<i32>,
137 ) -> Result<Vec<String>> {
138 let suggestions = self.ckan.dataset_autocomplete(Some(partial), limit).await?;
139 Ok(suggestions.into_iter().filter_map(|s| s.name).collect())
140 }
141
142 pub async fn list_organizations(&self, limit: Option<i32>) -> Result<Vec<String>> {
144 let orgs = self.ckan.organization_list(None, limit, None).await?;
145 Ok(orgs)
146 }
147
148 pub async fn autocomplete_organizations(
150 &self,
151 partial: &str,
152 limit: Option<i32>,
153 ) -> Result<Vec<String>> {
154 let suggestions = self
155 .ckan
156 .organization_autocomplete(Some(partial), limit)
157 .await?;
158 Ok(suggestions.into_iter().filter_map(|s| s.name).collect())
159 }
160
161 pub fn get_downloadable_resources(package: &Package) -> Vec<Resource> {
168 package
169 .resources
170 .as_ref()
171 .unwrap_or(&Vec::new())
172 .iter()
173 .filter(|resource| {
174 resource.url.is_some()
176 && resource.url_type.as_deref() != Some("api")
177 && resource.format.is_some()
178 })
179 .cloned()
180 .collect()
181 }
182
183 pub fn get_resource_filename(
194 resource: &Resource,
195 fallback_name: Option<&str>,
196 resource_index: Option<usize>,
197 ) -> String {
198 let (base_filename, has_extension) = if let Some(name) = &resource.name {
200 if let Some(format) = &resource.format {
201 let format_lower = format.to_lowercase();
202 if name.ends_with(&format!(".{}", format_lower)) {
203 (name.clone(), true)
204 } else {
205 (format!("{}.{}", name, format_lower), true)
206 }
207 } else {
208 (name.clone(), false)
209 }
210 } else if let Some(url) = &resource.url
211 && let Ok(parsed_url) = Url::parse(url)
212 && let Some(mut segments) = parsed_url.path_segments()
213 && let Some(filename) = segments.next_back()
214 && !filename.is_empty()
215 && filename.contains('.')
216 {
217 (filename.to_string(), true)
218 } else {
219 let base_name = fallback_name.unwrap_or("data");
221 if let Some(format) = &resource.format {
222 (format!("{}.{}", base_name, format.to_lowercase()), true)
223 } else {
224 (format!("{}.dat", base_name), true)
225 }
226 };
227
228 if let Some(index) = resource_index {
230 if has_extension {
231 if let Some(dot_pos) = base_filename.rfind('.') {
233 let (name, ext) = base_filename.split_at(dot_pos);
234 return format!("{}-{}{}", name, index, ext);
235 }
236 }
237 format!("{}-{}", base_filename, index)
239 } else {
240 base_filename
241 }
242 }
243
244 pub async fn download_resource(
254 &self,
255 resource: &Resource,
256 output_dir: Option<&Path>,
257 ) -> Result<PathBuf> {
258 let url = match resource.url.as_deref() {
259 Some(url) => url,
260 None => {
261 if let Some(reporter) = self.config.status_reporter.as_ref() {
262 let event = DownloadFailed {
263 resource_name: resource.name.clone(),
264 dataset_name: None,
265 output_path: None,
266 error: "Resource has no URL".to_string(),
267 };
268 reporter.on_download_failed(&event);
269 }
270 return Err(DataGovError::resource_not_found("Resource has no URL"));
271 }
272 };
273
274 let output_dir = output_dir
275 .map(|p| p.to_path_buf())
276 .unwrap_or_else(|| self.config.get_base_download_dir());
277 let filename = Self::get_resource_filename(resource, None, None);
279 let output_path = output_dir.join(filename);
280
281 Self::perform_download(
282 &self.http_client,
283 url,
284 &output_path,
285 resource.name.clone(),
286 None,
287 self.reporter(),
288 )
289 .await?;
290
291 Ok(output_path)
292 }
293
294 pub async fn download_resources(
302 &self,
303 resources: &[Resource],
304 output_dir: Option<&Path>,
305 ) -> Vec<Result<PathBuf>> {
306 if resources.is_empty() {
307 return vec![];
308 }
309
310 if resources.len() == 1 {
311 return vec![self.download_resource(&resources[0], output_dir).await];
312 }
313
314 if let Some(reporter) = self.config.status_reporter.as_ref() {
316 let event = DownloadBatch {
317 resource_count: resources.len(),
318 dataset_name: None,
319 };
320 reporter.on_download_batch(&event);
321 }
322
323 let output_dir = output_dir
324 .map(|p| p.to_path_buf())
325 .unwrap_or_else(|| self.config.get_base_download_dir());
326
327 let semaphore = Arc::new(tokio::sync::Semaphore::new(
328 self.config.max_concurrent_downloads,
329 ));
330
331 let status_reporter = self.reporter();
332 let mut futures = Vec::with_capacity(resources.len());
333
334 for (resource_index, resource) in resources.iter().enumerate() {
335 let resource = resource.clone();
336 let output_dir = output_dir.clone();
337 let semaphore = semaphore.clone();
338 let http_client = self.http_client.clone();
339 let status_reporter = status_reporter.clone();
340
341 let future = async move {
342 let _permit = match semaphore.acquire().await {
343 Ok(permit) => permit,
344 Err(e) => {
345 if let Some(reporter) = status_reporter.as_ref() {
346 let event = DownloadFailed {
347 resource_name: resource.name.clone(),
348 dataset_name: None,
349 output_path: None,
350 error: format!("Failed to acquire download slot: {}", e),
351 };
352 reporter.on_download_failed(&event);
353 }
354 return Err(DataGovError::download_error(format!(
355 "Semaphore error: {}",
356 e
357 )));
358 }
359 };
360
361 let url = match resource.url.as_deref() {
362 Some(url) => url,
363 None => {
364 if let Some(reporter) = status_reporter.as_ref() {
365 let event = DownloadFailed {
366 resource_name: resource.name.clone(),
367 dataset_name: None,
368 output_path: None,
369 error: "Resource has no URL".to_string(),
370 };
371 reporter.on_download_failed(&event);
372 }
373 return Err(DataGovError::resource_not_found("Resource has no URL"));
374 }
375 };
376
377 let filename =
379 DataGovClient::get_resource_filename(&resource, None, Some(resource_index));
380 let output_path = output_dir.join(&filename);
381
382 DataGovClient::perform_download(
383 &http_client,
384 url,
385 &output_path,
386 resource.name.clone(),
387 None,
388 status_reporter,
389 )
390 .await?;
391
392 Ok(output_path)
393 };
394
395 futures.push(future);
396 }
397
398 futures::future::join_all(futures).await
399 }
400
401 fn reporter(&self) -> Option<Arc<dyn StatusReporter + Send + Sync>> {
402 self.config.status_reporter.clone()
403 }
404
405 async fn perform_download(
406 http_client: &reqwest::Client,
407 url: &str,
408 output_path: &Path,
409 resource_name: Option<String>,
410 dataset_name: Option<String>,
411 status_reporter: Option<Arc<dyn StatusReporter + Send + Sync>>,
412 ) -> Result<()> {
413 let notify_failure =
414 |message: String, status_reporter: &Option<Arc<dyn StatusReporter + Send + Sync>>| {
415 if let Some(reporter) = status_reporter.as_ref() {
416 let event = DownloadFailed {
417 resource_name: resource_name.clone(),
418 dataset_name: dataset_name.clone(),
419 output_path: Some(output_path.to_path_buf()),
420 error: message.clone(),
421 };
422 reporter.on_download_failed(&event);
423 }
424 };
425
426 if let Some(parent) = output_path.parent()
427 && let Err(err) = tokio::fs::create_dir_all(parent).await
428 {
429 notify_failure(err.to_string(), &status_reporter);
430 return Err(err.into());
431 }
432
433 let response = match http_client.get(url).send().await {
434 Ok(resp) => resp,
435 Err(err) => {
436 notify_failure(err.to_string(), &status_reporter);
437 return Err(err.into());
438 }
439 };
440
441 if !response.status().is_success() {
442 let message = format!("HTTP {} while downloading {}", response.status(), url);
443 notify_failure(message.clone(), &status_reporter);
444 return Err(DataGovError::download_error(message));
445 }
446
447 let total_size = response.content_length();
448
449 if let Some(reporter) = status_reporter.as_ref() {
450 let event = DownloadStarted {
451 resource_name: resource_name.clone(),
452 dataset_name: dataset_name.clone(),
453 url: url.to_string(),
454 output_path: output_path.to_path_buf(),
455 total_bytes: total_size,
456 };
457 reporter.on_download_started(&event);
458 }
459
460 let mut file = match File::create(output_path).await {
461 Ok(file) => file,
462 Err(err) => {
463 notify_failure(err.to_string(), &status_reporter);
464 return Err(err.into());
465 }
466 };
467
468 let mut stream = response.bytes_stream();
469 let mut downloaded = 0u64;
470
471 while let Some(chunk_result) = stream.next().await {
472 let chunk = match chunk_result {
473 Ok(chunk) => chunk,
474 Err(err) => {
475 notify_failure(err.to_string(), &status_reporter);
476 return Err(err.into());
477 }
478 };
479
480 if let Err(err) = file.write_all(&chunk).await {
481 notify_failure(err.to_string(), &status_reporter);
482 return Err(err.into());
483 }
484
485 downloaded += chunk.len() as u64;
486
487 if let Some(reporter) = status_reporter.as_ref() {
488 let event = DownloadProgress {
489 resource_name: resource_name.clone(),
490 dataset_name: dataset_name.clone(),
491 output_path: output_path.to_path_buf(),
492 downloaded_bytes: downloaded,
493 total_bytes: total_size,
494 };
495 reporter.on_download_progress(&event);
496 }
497 }
498
499 if let Some(reporter) = status_reporter.as_ref() {
500 let event = DownloadFinished {
501 resource_name,
502 dataset_name,
503 output_path: output_path.to_path_buf(),
504 };
505 reporter.on_download_finished(&event);
506 }
507
508 Ok(())
509 }
510
511 pub async fn validate_download_dir(&self) -> Result<()> {
513 let base_dir = self.config.get_base_download_dir();
514
515 if !base_dir.exists() {
516 tokio::fs::create_dir_all(&base_dir).await?;
517 }
518
519 if !base_dir.is_dir() {
520 return Err(DataGovError::config_error(format!(
521 "Download path is not a directory: {:?}",
522 base_dir
523 )));
524 }
525
526 let test_file = base_dir.join(".write_test");
527 tokio::fs::write(&test_file, b"test").await?;
528 tokio::fs::remove_file(&test_file).await?;
529
530 Ok(())
531 }
532
533 pub fn download_dir(&self) -> PathBuf {
535 self.config.get_base_download_dir()
536 }
537
538 pub fn ckan_client(&self) -> &CkanClient {
540 &self.ckan
541 }
542}
543
544impl Default for DataGovClient {
545 fn default() -> Self {
546 Self::new().expect("Failed to create default DataGovClient")
547 }
548}
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a resource with a name and URL and an optional format; every other
    /// field keeps its default value.
    fn fixture(name: &str, format: Option<&str>, url: &str) -> Resource {
        Resource {
            name: Some(name.to_string()),
            format: format.map(String::from),
            url: Some(url.to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_get_resource_filename_no_index() {
        let resource = fixture("data", Some("CSV"), "https://example.com/data.csv");

        assert_eq!(
            DataGovClient::get_resource_filename(&resource, None, None),
            "data.csv"
        );
    }

    #[test]
    fn test_get_resource_filename_with_index() {
        let resource = fixture("data", Some("CSV"), "https://example.com/data.csv");

        for index in 0..3usize {
            let filename = DataGovClient::get_resource_filename(&resource, None, Some(index));
            assert_eq!(filename, format!("data-{}.csv", index));
        }
    }

    #[test]
    fn test_get_resource_filename_already_has_extension() {
        let resource = fixture("report.csv", Some("CSV"), "https://example.com/report.csv");

        assert_eq!(
            DataGovClient::get_resource_filename(&resource, None, Some(3)),
            "report-3.csv"
        );
    }

    #[test]
    fn test_get_resource_filename_no_format() {
        let resource = fixture("myfile", None, "https://example.com/myfile");

        assert_eq!(
            DataGovClient::get_resource_filename(&resource, None, Some(5)),
            "myfile-5"
        );
    }

    #[test]
    fn test_get_resource_filename_multiple_extensions() {
        let resource = fixture(
            "archive.tar.gz",
            Some("TAR"),
            "https://example.com/archive.tar.gz",
        );

        // The name does not end in ".tar", so the format is appended first
        // ("archive.tar.gz.tar") and the index goes in front of that last extension.
        assert_eq!(
            DataGovClient::get_resource_filename(&resource, None, Some(7)),
            "archive.tar.gz-7.tar"
        );
    }
}