1use std::collections::HashMap;
10use std::io::Read;
11use std::path::Path;
12use std::time::Duration;
13
14use scirs2_core::ndarray::{Array1, Array2};
15use serde::{Deserialize, Serialize};
16
17use crate::cache::DatasetCache;
18use crate::error::{DatasetsError, Result};
19use crate::loaders::{load_csv, CsvConfig};
20use crate::utils::Dataset;
21
/// Configuration for fetching datasets from external (network) sources.
///
/// Defaults are provided by the [`Default`] impl below (5-minute timeout,
/// 3 retries, SSL verification and caching enabled).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalConfig {
    /// Request timeout in seconds.
    pub timeout_seconds: u64,
    /// Maximum number of retry attempts after a failed request.
    pub max_retries: u32,
    /// `User-Agent` header value sent with every request.
    pub user_agent: String,
    /// Extra HTTP headers added to every request (e.g. auth tokens).
    pub headers: HashMap<String, String>,
    /// When `false`, TLS certificate validation is disabled for requests.
    pub verify_ssl: bool,
    /// Whether downloaded payloads are stored in / served from the dataset cache.
    pub use_cache: bool,
}
38
39impl Default for ExternalConfig {
40 fn default() -> Self {
41 Self {
42 timeout_seconds: 300, max_retries: 3,
44 user_agent: "scirs2-datasets/0.1.0".to_string(),
45 headers: HashMap::new(),
46 verify_ssl: true,
47 use_cache: true,
48 }
49 }
50}
51
/// Callback invoked during downloads with `(bytes_downloaded, total_bytes)`;
/// `total_bytes` is `0` when the server did not report a content length.
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;
54
/// Client that downloads external datasets, caches raw payloads, and parses
/// them into [`Dataset`] values.
pub struct ExternalClient {
    config: ExternalConfig,
    cache: DatasetCache,
    // The HTTP client only exists when the `download` feature pulls in reqwest.
    #[cfg(feature = "download")]
    client: reqwest::Client,
}
62
impl ExternalClient {
    /// Creates a client with [`ExternalConfig::default`] settings.
    pub fn new() -> Result<Self> {
        Self::with_config(ExternalConfig::default())
    }

    /// Creates a client from an explicit configuration.
    ///
    /// Opens the dataset cache directory and, when the `download` feature is
    /// enabled, builds a `reqwest::Client` honoring the configured timeout,
    /// user agent, and SSL-verification setting.
    ///
    /// # Errors
    /// Fails if the cache directory cannot be resolved or the HTTP client
    /// cannot be built.
    pub fn with_config(config: ExternalConfig) -> Result<Self> {
        let cache = DatasetCache::new(crate::cache::get_cachedir()?);

        #[cfg(feature = "download")]
        let client = {
            let mut builder = reqwest::Client::builder()
                .timeout(Duration::from_secs(config.timeout_seconds))
                .user_agent(&config.user_agent);

            if !config.verify_ssl {
                // Caller explicitly opted out of certificate checks.
                builder = builder.danger_accept_invalid_certs(true);
            }

            builder
                .build()
                .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?
        };

        Ok(Self {
            config,
            cache,
            #[cfg(feature = "download")]
            client,
        })
    }

    /// Downloads a dataset from `url`, reporting progress via `progress`.
    ///
    /// When caching is enabled, a cache hit (keyed by the blake3 hash of the
    /// URL) short-circuits the download; on a miss the raw bytes are written
    /// back to the cache after a successful download.
    ///
    /// # Errors
    /// Propagates HTTP/IO failures and parse errors from the payload.
    #[cfg(feature = "download")]
    pub async fn download_dataset(
        &self,
        url: &str,
        progress: Option<ProgressCallback>,
    ) -> Result<Dataset> {
        if self.config.use_cache {
            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
            if let Ok(cached_data) = self.cache.read_cached(&cache_key) {
                return self.parse_cached_data(&cached_data);
            }
        }

        let response = self.make_request(url).await?;
        // 0 signals "unknown size" to the progress callback.
        let total_size = response.content_length().unwrap_or(0);

        let mut downloaded = 0u64;
        let mut buffer = Vec::new();
        let mut stream = response.bytes_stream();

        use futures_util::StreamExt;
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
            downloaded += chunk.len() as u64;
            buffer.extend_from_slice(&chunk);

            // Progress is reported once per received chunk.
            if let Some(ref callback) = progress {
                callback(downloaded, total_size);
            }
        }

        if self.config.use_cache {
            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
            // Cache write failures are deliberately ignored: caching is best-effort.
            let _ = self.cache.put(&cache_key, &buffer);
        }

        self.parse_downloaded_data(url, &buffer)
    }

    /// Blocking wrapper around [`Self::download_dataset`].
    ///
    /// Spins up a fresh tokio runtime per call.
    /// NOTE(review): must not be called from within an existing async
    /// runtime — `block_on` would panic there; confirm callers are sync-only.
    #[cfg(feature = "download")]
    pub fn download_dataset_sync(
        &self,
        url: &str,
        progress: Option<ProgressCallback>,
    ) -> Result<Dataset> {
        let rt = tokio::runtime::Runtime::new()
            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
        rt.block_on(self.download_dataset(url, progress))
    }

    /// Synchronous download using `ureq` when only `download-sync` is enabled.
    #[cfg(not(feature = "download"))]
    #[cfg(feature = "download-sync")]
    pub fn download_dataset_sync(
        &self,
        url: &str,
        progress: Option<ProgressCallback>,
    ) -> Result<Dataset> {
        self.download_with_ureq(url, progress)
    }

    /// Stub used when neither download feature is enabled: always errors.
    #[cfg(not(feature = "download"))]
    #[cfg(not(feature = "download-sync"))]
    pub fn download_dataset_sync(
        &self,
        _url: &str,
        _progress: Option<ProgressCallback>,
    ) -> Result<Dataset> {
        Err(DatasetsError::FormatError(
            "Synchronous download feature is disabled. Enable 'download-sync' feature or use async download.".to_string()
        ))
    }

    /// `ureq`-based synchronous download with the same cache behavior as the
    /// async path (blake3-of-URL key, best-effort write-back).
    ///
    /// Unlike the streaming async path, the whole body is read into memory
    /// first, so the progress callback fires exactly once, at completion.
    #[cfg(feature = "download-sync")]
    #[allow(dead_code)]
    fn download_with_ureq(&self, url: &str, progress: Option<ProgressCallback>) -> Result<Dataset> {
        if self.config.use_cache {
            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
            if let Ok(cached_data) = self.cache.read_cached(&cache_key) {
                return self.parse_cached_data(&cached_data);
            }
        }

        let mut request = ureq::get(url).header("User-Agent", &self.config.user_agent);

        for (key, value) in &self.config.headers {
            request = request.header(key, value);
        }

        // NOTE(review): no retry loop here, unlike make_request() — confirm
        // whether the sync path should honor config.max_retries too.
        let response = request
            .call()
            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;

        let headers = response.headers();
        // Falls back to 0 ("unknown") when the header is absent or malformed.
        let total_size = headers
            .get("Content-Length")
            .and_then(|hv| hv.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(0);

        let mut body = response.into_body();
        let buffer = body
            .read_to_vec()
            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
        let downloaded = buffer.len() as u64;
        if let Some(ref callback) = progress {
            callback(downloaded, total_size);
        }

        if self.config.use_cache {
            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
            // Best-effort cache write; failures are ignored.
            let _ = self.cache.put(&cache_key, &buffer);
        }

        self.parse_downloaded_data(url, &buffer)
    }

    /// Sends a GET request with configured headers, retrying up to
    /// `config.max_retries` times with linear backoff (1s, 2s, ...).
    ///
    /// Non-2xx statuses count as failures and are retried like transport
    /// errors; the last error is returned when all attempts are exhausted.
    #[cfg(feature = "download")]
    async fn make_request(&self, url: &str) -> Result<reqwest::Response> {
        let mut request = self.client.get(url);

        for (key, value) in &self.config.headers {
            request = request.header(key, value);
        }

        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            // The builder is consumed by send(), so each attempt clones it.
            match request
                .try_clone()
                .ok_or_else(|| {
                    DatasetsError::IoError(std::io::Error::other("Failed to clone request"))
                })?
                .send()
                .await
            {
                Ok(response) => {
                    if response.status().is_success() {
                        return Ok(response);
                    } else {
                        last_error = Some(DatasetsError::IoError(std::io::Error::other(format!(
                            "HTTP {}: {}",
                            response.status(),
                            response.status().canonical_reason().unwrap_or("Unknown")
                        ))));
                    }
                }
                Err(e) => {
                    last_error = Some(DatasetsError::IoError(std::io::Error::other(e)));
                }
            }

            if attempt < self.config.max_retries {
                tokio::time::sleep(Duration::from_millis(1000 * (attempt + 1) as u64)).await;
            }
        }

        // The loop runs at least once and every path sets last_error before
        // reaching here, so this expect cannot fire.
        Err(last_error.expect("Operation failed"))
    }

    /// Parses cached bytes: tries the serialized `Dataset` JSON form first,
    /// then falls back to raw-format parsing with no extension hint.
    fn parse_cached_data(&self, data: &[u8]) -> Result<Dataset> {
        if let Ok(dataset) = serde_json::from_slice::<Dataset>(data) {
            return Ok(dataset);
        }

        self.parse_raw_data(data, None)
    }

    /// Parses freshly downloaded bytes, using the URL's lowercased file
    /// extension as the format hint.
    fn parse_downloaded_data(&self, url: &str, data: &[u8]) -> Result<Dataset> {
        let extension = Path::new(url)
            .extension()
            .and_then(|s| s.to_str())
            .unwrap_or("")
            .to_lowercase();

        self.parse_raw_data(data, Some(&extension))
    }

    /// Dispatches raw bytes to a parser based on the extension hint.
    ///
    /// `None` (cache fallback) is treated the same as "csv". Unrecognized
    /// extensions go through auto-detection.
    fn parse_raw_data(&self, data: &[u8], extension: Option<&str>) -> Result<Dataset> {
        match extension {
            Some("csv") | None => {
                let csv_data = String::from_utf8(data.to_vec())
                    .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;

                // load_csv takes a path, so round-trip through a temp file.
                let temp_file = tempfile::NamedTempFile::new().map_err(DatasetsError::IoError)?;

                std::fs::write(temp_file.path(), &csv_data).map_err(DatasetsError::IoError)?;

                load_csv(temp_file.path(), CsvConfig::default())
            }
            Some("json") => {
                let json_str = String::from_utf8(data.to_vec())
                    .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;

                serde_json::from_str(&json_str)
                    .map_err(|e| DatasetsError::FormatError(format!("Invalid JSON: {e}")))
            }
            Some("arff") => {
                self.parse_arff_data(data)
            }
            _ => {
                self.auto_detect_and_parse(data)
            }
        }
    }

    /// Minimal ARFF parser: collects `@attribute` names, then numeric data
    /// rows after `@data`. With more than one column, the last column becomes
    /// the target; rows containing any non-numeric value are skipped.
    ///
    /// NOTE(review): `attributes[..data_cols]` panics when the file declares
    /// fewer `@attribute` lines than the data has columns — confirm inputs
    /// are always well-formed or guard this slice.
    fn parse_arff_data(&self, data: &[u8]) -> Result<Dataset> {
        let content = String::from_utf8(data.to_vec())
            .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;

        let lines = content.lines();
        let mut attributes = Vec::new();
        let mut data_section = false;
        let mut data_lines = Vec::new();

        for line in lines {
            let line = line.trim();

            // '%' starts an ARFF comment line.
            if line.is_empty() || line.starts_with('%') {
                continue;
            }

            if line.to_lowercase().starts_with("@attribute") {
                // Second whitespace-separated token is the attribute name.
                let parts: Vec<&str> = line.split_whitespace().collect();
                if parts.len() >= 2 {
                    attributes.push(parts[1].to_string());
                }
            } else if line.to_lowercase().starts_with("@data") {
                data_section = true;
            } else if data_section {
                data_lines.push(line.to_string());
            }
        }

        let mut rows: Vec<Vec<f64>> = Vec::new();
        for line in data_lines {
            let values: Result<Vec<f64>> = line
                .split(',')
                .map(|s| {
                    s.trim()
                        .parse::<f64>()
                        .map_err(|_| DatasetsError::FormatError(format!("Invalid number: {s}")))
                })
                .collect();

            match values {
                // Rows with unparseable values are silently dropped.
                Ok(row) => rows.push(row),
                Err(_) => continue,
            }
        }

        if rows.is_empty() {
            return Err(DatasetsError::FormatError(
                "No valid data rows found".to_string(),
            ));
        }

        // Column count is taken from the first row; ragged rows surface later
        // as a shape error from Array2::from_shape_vec.
        let n_features = rows[0].len();
        let n_samples = rows.len();

        // Last column is treated as the target when there is more than one.
        let (data_cols, target_col) = if n_features > 1 {
            (n_features - 1, Some(n_features - 1))
        } else {
            (n_features, None)
        };

        let mut data_vec = Vec::with_capacity(n_samples * data_cols);
        let mut target_vec = if target_col.is_some() {
            Some(Vec::with_capacity(n_samples))
        } else {
            None
        };

        for row in rows {
            for (i, &value) in row.iter().enumerate() {
                if i < data_cols {
                    data_vec.push(value);
                } else if let Some(ref mut targets) = target_vec {
                    targets.push(value);
                }
            }
        }

        let data = Array2::from_shape_vec((n_samples, data_cols), data_vec)
            .map_err(|e| DatasetsError::FormatError(e.to_string()))?;

        let target = target_vec.map(Array1::from_vec);

        Ok(Dataset {
            data,
            target,
            featurenames: Some(attributes[..data_cols].to_vec()),
            targetnames: None,
            feature_descriptions: None,
            description: Some("ARFF dataset loaded from external source".to_string()),
            metadata: std::collections::HashMap::new(),
        })
    }

    /// Heuristic format detection, tried in order: JSON (`{`/`[` prefix),
    /// CSV/TSV (contains ',' or tab), ARFF (contains `@relation`).
    ///
    /// NOTE(review): ARFF content with comma-separated data rows matches the
    /// CSV branch first — confirm this precedence is intended.
    fn auto_detect_and_parse(&self, data: &[u8]) -> Result<Dataset> {
        let content = String::from_utf8(data.to_vec())
            .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;

        if content.trim().starts_with('{') || content.trim().starts_with('[') {
            if let Ok(dataset) = serde_json::from_str::<Dataset>(&content) {
                return Ok(dataset);
            }
        }

        if content.contains(',') || content.contains('\t') {
            return self.parse_raw_data(data, Some("csv"));
        }

        if content.to_lowercase().contains("@relation") {
            return self.parse_arff_data(data);
        }

        Err(DatasetsError::FormatError(
            "Unable to auto-detect data format".to_string(),
        ))
    }
}
448
449pub mod repositories {
451 use super::*;
452
453 pub struct UCIRepository {
455 client: ExternalClient,
456 base_url: String,
457 }
458
459 impl UCIRepository {
460 pub fn new() -> Result<Self> {
462 Ok(Self {
463 client: ExternalClient::new()?,
464 base_url: "https://archive.ics.uci.edu/ml/machine-learning-databases".to_string(),
465 })
466 }
467
468 #[cfg(feature = "download")]
476 pub async fn load_dataset(&self, name: &str) -> Result<Dataset> {
477 let url = match name {
478 "adult" => format!("{}/adult/adult.data", self.base_url),
479 "wine" => format!("{}/wine/wine.data", self.base_url),
480 "glass" => format!("{}/glass/glass.data", self.base_url),
481 "hepatitis" => format!("{}/hepatitis/hepatitis.data", self.base_url),
482 "heart-disease" => {
483 format!("{}/heart-disease/processed.cleveland.data", self.base_url)
484 }
485 _ => {
486 return Err(DatasetsError::NotFound(format!(
487 "UCI dataset '{name}' not found"
488 )))
489 }
490 };
491
492 self.client.download_dataset(&url, None).await
493 }
494
495 #[cfg(not(feature = "download"))]
496 pub fn load_dataset_sync(&self, name: &str) -> Result<Dataset> {
498 let url = match name {
499 "adult" => format!("{}/adult/adult.data", self.base_url),
500 "wine" => format!("{}/wine/wine.data", self.base_url),
501 "glass" => format!("{}/glass/glass.data", self.base_url),
502 "hepatitis" => format!("{}/hepatitis/hepatitis.data", self.base_url),
503 "heart-disease" => {
504 format!("{}/heart-disease/processed.cleveland.data", self.base_url)
505 }
506 _ => {
507 return Err(DatasetsError::NotFound(format!(
508 "UCI dataset '{name}' not found"
509 )))
510 }
511 };
512
513 self.client.download_dataset_sync(&url, None)
514 }
515
516 pub fn list_datasets(&self) -> Vec<&'static str> {
518 vec!["adult", "wine", "glass", "hepatitis", "heart-disease"]
519 }
520 }
521
522 pub struct KaggleRepository {
524 #[allow(dead_code)]
525 client: ExternalClient,
526 #[allow(dead_code)]
527 api_key: Option<String>,
528 }
529
530 impl KaggleRepository {
531 pub fn new(_apikey: Option<String>) -> Result<Self> {
533 let mut config = ExternalConfig::default();
534
535 if let Some(ref key) = _apikey {
536 config
537 .headers
538 .insert("Authorization".to_string(), format!("Bearer {key}"));
539 }
540
541 Ok(Self {
542 client: ExternalClient::with_config(config)?,
543 api_key: _apikey,
544 })
545 }
546
547 #[cfg(feature = "download")]
555 pub async fn load_competition_data(&self, competition: &str) -> Result<Dataset> {
556 if self.api_key.is_none() {
557 return Err(DatasetsError::AuthenticationError(
558 "Kaggle API key required".to_string(),
559 ));
560 }
561
562 let url = format!(
563 "https://www.kaggle.com/api/v1/competitions/{}/data/download",
564 competition
565 );
566 self.client.download_dataset(&url, None).await
567 }
568 }
569
570 pub struct GitHubRepository {
572 client: ExternalClient,
573 }
574
575 impl GitHubRepository {
576 pub fn new() -> Result<Self> {
578 Ok(Self {
579 client: ExternalClient::new()?,
580 })
581 }
582
583 #[cfg(feature = "download")]
593 pub async fn load_from_repo(&self, user: &str, repo: &str, path: &str) -> Result<Dataset> {
594 let url = format!("https://raw.githubusercontent.com/{user}/{repo}/main/{path}");
595 self.client.download_dataset(&url, None).await
596 }
597
598 #[cfg(not(feature = "download"))]
599 pub fn load_from_repo_sync(&self, user: &str, repo: &str, path: &str) -> Result<Dataset> {
601 let url = format!("https://raw.githubusercontent.com/{user}/{repo}/main/{path}");
602 self.client.download_dataset_sync(&url, None)
603 }
604 }
605}
606
607pub mod convenience {
609 use super::repositories::*;
610 use super::*;
611
612 #[cfg(feature = "download")]
614 pub async fn load_from_url(url: &str, config: Option<ExternalConfig>) -> Result<Dataset> {
615 let client = match config {
616 Some(cfg) => ExternalClient::with_config(cfg)?,
617 None => ExternalClient::new()?,
618 };
619
620 client
621 .download_dataset(
622 url,
623 Some(Box::new(|downloaded, total| {
624 if let Some(percent) = (downloaded * 100).checked_div(total) {
625 eprintln!("Downloaded: {percent:.1}% ({downloaded}/{total})");
626 } else {
627 eprintln!("Downloaded: {downloaded} bytes");
628 }
629 })),
630 )
631 .await
632 }
633
634 pub fn load_from_url_sync(url: &str, config: Option<ExternalConfig>) -> Result<Dataset> {
636 let client = match config {
637 Some(cfg) => ExternalClient::with_config(cfg)?,
638 None => ExternalClient::new()?,
639 };
640
641 client.download_dataset_sync(
642 url,
643 Some(Box::new(|downloaded, total| {
644 if let Some(percent) = (downloaded * 100).checked_div(total) {
645 eprintln!("Downloaded: {percent:.1}% ({downloaded}/{total})");
646 } else {
647 eprintln!("Downloaded: {downloaded} bytes");
648 }
649 })),
650 )
651 }
652
653 #[cfg(feature = "download")]
655 pub async fn load_uci_dataset(name: &str) -> Result<Dataset> {
656 let repo = UCIRepository::new()?;
657 repo.load_dataset(name).await
658 }
659
660 #[cfg(not(feature = "download"))]
662 pub fn load_uci_dataset_sync(name: &str) -> Result<Dataset> {
663 let repo = UCIRepository::new()?;
664 repo.load_dataset_sync(name)
665 }
666
667 #[cfg(feature = "download")]
669 pub async fn load_github_dataset(user: &str, repo: &str, path: &str) -> Result<Dataset> {
670 let github = GitHubRepository::new()?;
671 github.load_from_repo(user, repo, path).await
672 }
673
674 #[cfg(not(feature = "download"))]
676 pub fn load_github_dataset_sync(user: &str, repo: &str, path: &str) -> Result<Dataset> {
677 let github = GitHubRepository::new()?;
678 github.load_from_repo_sync(user, repo, path)
679 }
680
681 pub fn list_uci_datasets() -> Result<Vec<&'static str>> {
683 let repo = UCIRepository::new()?;
684 Ok(repo.list_datasets())
685 }
686}
687
#[cfg(test)]
mod tests {
    use super::convenience::*;
    use super::*;

    /// The default configuration must match the documented values.
    #[test]
    fn test_external_config_default() {
        let cfg = ExternalConfig::default();
        assert_eq!(cfg.timeout_seconds, 300);
        assert_eq!(cfg.max_retries, 3);
        assert!(cfg.verify_ssl);
        assert!(cfg.use_cache);
    }

    /// The UCI listing is non-empty and contains well-known names.
    #[test]
    fn test_uci_repository_list_datasets() {
        let names = list_uci_datasets().expect("Operation failed");
        assert!(!names.is_empty());
        for expected in ["wine", "adult"] {
            assert!(names.contains(&expected));
        }
    }

    /// A three-column ARFF file yields two features plus a target column.
    #[test]
    fn test_parse_arff_data() {
        let arff_content = r#"
@relation test
@attribute feature1 numeric
@attribute feature2 numeric
@attribute class {0,1}
@data
1.0,2.0,0
3.0,4.0,1
5.0,6.0,0
"#;

        let client = ExternalClient::new().expect("Operation failed");
        let parsed = client
            .parse_arff_data(arff_content.as_bytes())
            .expect("Operation failed");

        assert_eq!(parsed.n_samples(), 3);
        assert_eq!(parsed.n_features(), 2);
        assert!(parsed.target.is_some());
    }

    /// Network smoke test; tolerates failures (e.g. in offline CI).
    #[tokio::test]
    #[cfg(feature = "download")]
    async fn test_download_small_csv() {
        let url = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv";

        match load_from_url(url, None).await {
            Ok(ds) => {
                assert!(ds.n_samples() > 0);
                assert!(ds.n_features() > 0);
            }
            Err(e) => {
                eprintln!("Network test failed (expected in CI): {}", e);
            }
        }
    }
}