// scirs2_datasets/external.rs
//! External data sources integration
//!
//! This module provides functionality for loading datasets from external sources including:
//! - URLs and web resources
//! - API endpoints
//! - Popular dataset repositories
//! - Remote file systems
9use std::collections::HashMap;
10use std::io::Read;
11use std::path::Path;
12use std::time::Duration;
13
14use scirs2_core::ndarray::{Array1, Array2};
15use serde::{Deserialize, Serialize};
16
17use crate::cache::DatasetCache;
18use crate::error::{DatasetsError, Result};
19use crate::loaders::{load_csv, CsvConfig};
20use crate::utils::Dataset;
21
/// Configuration for external data source access.
///
/// Controls timeouts, retry behavior, request headers, TLS verification,
/// and whether downloaded payloads are stored in the local dataset cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalConfig {
    /// Timeout for requests (in seconds)
    pub timeout_seconds: u64,
    /// Number of retry attempts (in addition to the initial request)
    pub max_retries: u32,
    /// User agent string for requests
    pub user_agent: String,
    /// Extra headers to include in every request
    pub headers: HashMap<String, String>,
    /// Whether to verify SSL certificates (disable only for trusted test hosts)
    pub verify_ssl: bool,
    /// Cache downloaded files locally so repeated fetches skip the network
    pub use_cache: bool,
}
38
39impl Default for ExternalConfig {
40    fn default() -> Self {
41        Self {
42            timeout_seconds: 300, // 5 minutes
43            max_retries: 3,
44            user_agent: "scirs2-datasets/0.1.0".to_string(),
45            headers: HashMap::new(),
46            verify_ssl: true,
47            use_cache: true,
48        }
49    }
50}
51
/// Progress callback for download operations.
///
/// Invoked with `(bytes_downloaded_so_far, total_bytes)`; `total_bytes`
/// is 0 when the server did not report a content length.
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;
54
/// External data source client.
///
/// Bundles the request configuration and the shared dataset cache; the
/// async HTTP client is only present when the `download` feature is
/// compiled in.
pub struct ExternalClient {
    config: ExternalConfig,
    cache: DatasetCache,
    // reqwest is only linked when async downloads are enabled
    #[cfg(feature = "download")]
    client: reqwest::Client,
}
62
63impl ExternalClient {
64    /// Create a new external client with default configuration
65    pub fn new() -> Result<Self> {
66        Self::with_config(ExternalConfig::default())
67    }
68
69    /// Create a new external client with custom configuration
70    pub fn with_config(config: ExternalConfig) -> Result<Self> {
71        let cache = DatasetCache::new(crate::cache::get_cachedir()?);
72
73        #[cfg(feature = "download")]
74        let client = {
75            let mut builder = reqwest::Client::builder()
76                .timeout(Duration::from_secs(config.timeout_seconds))
77                .user_agent(&config.user_agent);
78
79            if !config.verify_ssl {
80                builder = builder.danger_accept_invalid_certs(true);
81            }
82
83            builder
84                .build()
85                .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?
86        };
87
88        Ok(Self {
89            config,
90            cache,
91            #[cfg(feature = "download")]
92            client,
93        })
94    }
95
96    /// Download a dataset from a URL
97    #[cfg(feature = "download")]
98    pub async fn download_dataset(
99        &self,
100        url: &str,
101        progress: Option<ProgressCallback>,
102    ) -> Result<Dataset> {
103        // Check cache first
104        if self.config.use_cache {
105            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
106            if let Ok(cached_data) = self.cache.read_cached(&cache_key) {
107                return self.parse_cached_data(&cached_data);
108            }
109        }
110
111        // Download the file
112        let response = self.make_request(url).await?;
113        let total_size = response.content_length().unwrap_or(0);
114
115        let mut downloaded = 0u64;
116        let mut buffer = Vec::new();
117        let mut stream = response.bytes_stream();
118
119        use futures_util::StreamExt;
120        while let Some(chunk) = stream.next().await {
121            let chunk = chunk.map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
122            downloaded += chunk.len() as u64;
123            buffer.extend_from_slice(&chunk);
124
125            if let Some(ref callback) = progress {
126                callback(downloaded, total_size);
127            }
128        }
129
130        // Cache the downloaded data
131        if self.config.use_cache {
132            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
133            let _ = self.cache.put(&cache_key, &buffer);
134        }
135
136        // Parse the data based on content type or URL extension
137        self.parse_downloaded_data(url, &buffer)
138    }
139
140    /// Download a dataset synchronously (blocking) - when download feature is enabled
141    #[cfg(feature = "download")]
142    pub fn download_dataset_sync(
143        &self,
144        url: &str,
145        progress: Option<ProgressCallback>,
146    ) -> Result<Dataset> {
147        // Use tokio runtime to block on the async version
148        let rt = tokio::runtime::Runtime::new()
149            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
150        rt.block_on(self.download_dataset(url, progress))
151    }
152
153    /// Download a dataset synchronously (blocking) - fallback when download feature is disabled
154    #[cfg(not(feature = "download"))]
155    #[cfg(feature = "download-sync")]
156    pub fn download_dataset_sync(
157        &self,
158        url: &str,
159        progress: Option<ProgressCallback>,
160    ) -> Result<Dataset> {
161        // Fallback implementation using ureq
162        self.download_with_ureq(url, progress)
163    }
164
165    /// Stub for download_dataset_sync when download-sync feature is disabled
166    #[cfg(not(feature = "download"))]
167    #[cfg(not(feature = "download-sync"))]
168    pub fn download_dataset_sync(
169        &self,
170        _url: &str,
171        _progress: Option<ProgressCallback>,
172    ) -> Result<Dataset> {
173        Err(DatasetsError::FormatError(
174            "Synchronous download feature is disabled. Enable 'download-sync' feature or use async download.".to_string()
175        ))
176    }
177
178    /// Download using ureq (synchronous HTTP client)
179    #[cfg(feature = "download-sync")]
180    #[allow(dead_code)]
181    fn download_with_ureq(&self, url: &str, progress: Option<ProgressCallback>) -> Result<Dataset> {
182        // Check cache first
183        if self.config.use_cache {
184            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
185            if let Ok(cached_data) = self.cache.read_cached(&cache_key) {
186                return self.parse_cached_data(&cached_data);
187            }
188        }
189
190        let mut request = ureq::get(url).header("User-Agent", &self.config.user_agent);
191
192        // Add custom headers
193        for (key, value) in &self.config.headers {
194            request = request.header(key, value);
195        }
196
197        let response = request
198            .call()
199            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
200
201        // Get content-length header if present (case-insensitive per HTTP spec)
202        let headers = response.headers();
203        let total_size = headers
204            .get("Content-Length")
205            .and_then(|hv| hv.to_str().ok())
206            .and_then(|s| s.parse::<u64>().ok())
207            .unwrap_or(0);
208
209        // Read body via body reader (ureq 3.x)
210        let mut body = response.into_body();
211        let buffer = body
212            .read_to_vec()
213            .map_err(|e| DatasetsError::IoError(std::io::Error::other(e)))?;
214        let downloaded = buffer.len() as u64;
215        if let Some(ref callback) = progress {
216            callback(downloaded, total_size);
217        }
218
219        // Cache the downloaded data
220        if self.config.use_cache {
221            let cache_key = format!("external_{}", blake3::hash(url.as_bytes()).to_hex());
222            let _ = self.cache.put(&cache_key, &buffer);
223        }
224
225        // Parse the data
226        self.parse_downloaded_data(url, &buffer)
227    }
228
229    #[cfg(feature = "download")]
230    async fn make_request(&self, url: &str) -> Result<reqwest::Response> {
231        let mut request = self.client.get(url);
232
233        // Add custom headers
234        for (key, value) in &self.config.headers {
235            request = request.header(key, value);
236        }
237
238        let mut last_error = None;
239
240        for attempt in 0..=self.config.max_retries {
241            match request
242                .try_clone()
243                .ok_or_else(|| {
244                    DatasetsError::IoError(std::io::Error::other("Failed to clone request"))
245                })?
246                .send()
247                .await
248            {
249                Ok(response) => {
250                    if response.status().is_success() {
251                        return Ok(response);
252                    } else {
253                        last_error = Some(DatasetsError::IoError(std::io::Error::other(format!(
254                            "HTTP {}: {}",
255                            response.status(),
256                            response.status().canonical_reason().unwrap_or("Unknown")
257                        ))));
258                    }
259                }
260                Err(e) => {
261                    last_error = Some(DatasetsError::IoError(std::io::Error::other(e)));
262                }
263            }
264
265            if attempt < self.config.max_retries {
266                tokio::time::sleep(Duration::from_millis(1000 * (attempt + 1) as u64)).await;
267            }
268        }
269
270        Err(last_error.expect("Operation failed"))
271    }
272
273    fn parse_cached_data(&self, data: &[u8]) -> Result<Dataset> {
274        // Try to deserialize as JSON first (cached parsed data)
275        if let Ok(dataset) = serde_json::from_slice::<Dataset>(data) {
276            return Ok(dataset);
277        }
278
279        // Otherwise parse as raw data
280        self.parse_raw_data(data, None)
281    }
282
283    fn parse_downloaded_data(&self, url: &str, data: &[u8]) -> Result<Dataset> {
284        let extension = Path::new(url)
285            .extension()
286            .and_then(|s| s.to_str())
287            .unwrap_or("")
288            .to_lowercase();
289
290        self.parse_raw_data(data, Some(&extension))
291    }
292
293    fn parse_raw_data(&self, data: &[u8], extension: Option<&str>) -> Result<Dataset> {
294        match extension {
295            Some("csv") | None => {
296                // Try CSV parsing
297                let csv_data = String::from_utf8(data.to_vec())
298                    .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;
299
300                // Write to temporary file for CSV parsing
301                let temp_file = tempfile::NamedTempFile::new().map_err(DatasetsError::IoError)?;
302
303                std::fs::write(temp_file.path(), &csv_data).map_err(DatasetsError::IoError)?;
304
305                load_csv(temp_file.path(), CsvConfig::default())
306            }
307            Some("json") => {
308                // Try JSON parsing
309                let json_str = String::from_utf8(data.to_vec())
310                    .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;
311
312                serde_json::from_str(&json_str)
313                    .map_err(|e| DatasetsError::FormatError(format!("Invalid JSON: {e}")))
314            }
315            Some("arff") => {
316                // Basic ARFF parsing (simplified)
317                self.parse_arff_data(data)
318            }
319            _ => {
320                // Try to auto-detect format
321                self.auto_detect_and_parse(data)
322            }
323        }
324    }
325
326    fn parse_arff_data(&self, data: &[u8]) -> Result<Dataset> {
327        let content = String::from_utf8(data.to_vec())
328            .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;
329
330        let lines = content.lines();
331        let mut attributes = Vec::new();
332        let mut data_section = false;
333        let mut data_lines = Vec::new();
334
335        for line in lines {
336            let line = line.trim();
337
338            if line.is_empty() || line.starts_with('%') {
339                continue;
340            }
341
342            if line.to_lowercase().starts_with("@attribute") {
343                let parts: Vec<&str> = line.split_whitespace().collect();
344                if parts.len() >= 2 {
345                    attributes.push(parts[1].to_string());
346                }
347            } else if line.to_lowercase().starts_with("@data") {
348                data_section = true;
349            } else if data_section {
350                data_lines.push(line.to_string());
351            }
352        }
353
354        // Parse data rows
355        let mut rows: Vec<Vec<f64>> = Vec::new();
356        for line in data_lines {
357            let values: Result<Vec<f64>> = line
358                .split(',')
359                .map(|s| {
360                    s.trim()
361                        .parse::<f64>()
362                        .map_err(|_| DatasetsError::FormatError(format!("Invalid number: {s}")))
363                })
364                .collect();
365
366            match values {
367                Ok(row) => rows.push(row),
368                Err(_) => continue, // Skip invalid rows
369            }
370        }
371
372        if rows.is_empty() {
373            return Err(DatasetsError::FormatError(
374                "No valid data rows found".to_string(),
375            ));
376        }
377
378        let n_features = rows[0].len();
379        let n_samples = rows.len();
380
381        // Assume last column is target if more than one column
382        let (data_cols, target_col) = if n_features > 1 {
383            (n_features - 1, Some(n_features - 1))
384        } else {
385            (n_features, None)
386        };
387
388        // Create data array
389        let mut data_vec = Vec::with_capacity(n_samples * data_cols);
390        let mut target_vec = if target_col.is_some() {
391            Some(Vec::with_capacity(n_samples))
392        } else {
393            None
394        };
395
396        for row in rows {
397            for (i, &value) in row.iter().enumerate() {
398                if i < data_cols {
399                    data_vec.push(value);
400                } else if let Some(ref mut targets) = target_vec {
401                    targets.push(value);
402                }
403            }
404        }
405
406        let data = Array2::from_shape_vec((n_samples, data_cols), data_vec)
407            .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
408
409        let target = target_vec.map(Array1::from_vec);
410
411        Ok(Dataset {
412            data,
413            target,
414            featurenames: Some(attributes[..data_cols].to_vec()),
415            targetnames: None,
416            feature_descriptions: None,
417            description: Some("ARFF dataset loaded from external source".to_string()),
418            metadata: std::collections::HashMap::new(),
419        })
420    }
421
422    fn auto_detect_and_parse(&self, data: &[u8]) -> Result<Dataset> {
423        let content = String::from_utf8(data.to_vec())
424            .map_err(|e| DatasetsError::FormatError(format!("Invalid UTF-8: {e}")))?;
425
426        // Try JSON first
427        if content.trim().starts_with('{') || content.trim().starts_with('[') {
428            if let Ok(dataset) = serde_json::from_str::<Dataset>(&content) {
429                return Ok(dataset);
430            }
431        }
432
433        // Try CSV
434        if content.contains(',') || content.contains('\t') {
435            return self.parse_raw_data(data, Some("csv"));
436        }
437
438        // Try ARFF
439        if content.to_lowercase().contains("@relation") {
440            return self.parse_arff_data(data);
441        }
442
443        Err(DatasetsError::FormatError(
444            "Unable to auto-detect data format".to_string(),
445        ))
446    }
447}
448
449/// Popular dataset repository APIs
450pub mod repositories {
451    use super::*;
452
453    /// UCI Machine Learning Repository client
454    pub struct UCIRepository {
455        client: ExternalClient,
456        base_url: String,
457    }
458
459    impl UCIRepository {
460        /// Create a new UCI repository client
461        pub fn new() -> Result<Self> {
462            Ok(Self {
463                client: ExternalClient::new()?,
464                base_url: "https://archive.ics.uci.edu/ml/machine-learning-databases".to_string(),
465            })
466        }
467
468        /// Loads a dataset from the UCI Machine Learning Repository.
469        ///
470        /// # Arguments
471        /// * `name` - The name of the dataset to load
472        ///
473        /// # Returns
474        /// A `Dataset` containing the loaded data
475        #[cfg(feature = "download")]
476        pub async fn load_dataset(&self, name: &str) -> Result<Dataset> {
477            let url = match name {
478                "adult" => format!("{}/adult/adult.data", self.base_url),
479                "wine" => format!("{}/wine/wine.data", self.base_url),
480                "glass" => format!("{}/glass/glass.data", self.base_url),
481                "hepatitis" => format!("{}/hepatitis/hepatitis.data", self.base_url),
482                "heart-disease" => {
483                    format!("{}/heart-disease/processed.cleveland.data", self.base_url)
484                }
485                _ => {
486                    return Err(DatasetsError::NotFound(format!(
487                        "UCI dataset '{name}' not found"
488                    )))
489                }
490            };
491
492            self.client.download_dataset(&url, None).await
493        }
494
495        #[cfg(not(feature = "download"))]
496        /// Load a UCI dataset synchronously
497        pub fn load_dataset_sync(&self, name: &str) -> Result<Dataset> {
498            let url = match name {
499                "adult" => format!("{}/adult/adult.data", self.base_url),
500                "wine" => format!("{}/wine/wine.data", self.base_url),
501                "glass" => format!("{}/glass/glass.data", self.base_url),
502                "hepatitis" => format!("{}/hepatitis/hepatitis.data", self.base_url),
503                "heart-disease" => {
504                    format!("{}/heart-disease/processed.cleveland.data", self.base_url)
505                }
506                _ => {
507                    return Err(DatasetsError::NotFound(format!(
508                        "UCI dataset '{name}' not found"
509                    )))
510                }
511            };
512
513            self.client.download_dataset_sync(&url, None)
514        }
515
516        /// List available UCI datasets
517        pub fn list_datasets(&self) -> Vec<&'static str> {
518            vec!["adult", "wine", "glass", "hepatitis", "heart-disease"]
519        }
520    }
521
522    /// Kaggle dataset client (requires API key)
523    pub struct KaggleRepository {
524        #[allow(dead_code)]
525        client: ExternalClient,
526        #[allow(dead_code)]
527        api_key: Option<String>,
528    }
529
530    impl KaggleRepository {
531        /// Create a new Kaggle repository client
532        pub fn new(_apikey: Option<String>) -> Result<Self> {
533            let mut config = ExternalConfig::default();
534
535            if let Some(ref key) = _apikey {
536                config
537                    .headers
538                    .insert("Authorization".to_string(), format!("Bearer {key}"));
539            }
540
541            Ok(Self {
542                client: ExternalClient::with_config(config)?,
543                api_key: _apikey,
544            })
545        }
546
547        /// Loads competition data from Kaggle.
548        ///
549        /// # Arguments
550        /// * `competition` - The name of the Kaggle competition
551        ///
552        /// # Returns
553        /// A `Dataset` containing the competition data
554        #[cfg(feature = "download")]
555        pub async fn load_competition_data(&self, competition: &str) -> Result<Dataset> {
556            if self.api_key.is_none() {
557                return Err(DatasetsError::AuthenticationError(
558                    "Kaggle API key required".to_string(),
559                ));
560            }
561
562            let url = format!(
563                "https://www.kaggle.com/api/v1/competitions/{}/data/download",
564                competition
565            );
566            self.client.download_dataset(&url, None).await
567        }
568    }
569
570    /// GitHub repository client for datasets
571    pub struct GitHubRepository {
572        client: ExternalClient,
573    }
574
575    impl GitHubRepository {
576        /// Create a new GitHub repository client
577        pub fn new() -> Result<Self> {
578            Ok(Self {
579                client: ExternalClient::new()?,
580            })
581        }
582
583        /// Loads a dataset from a GitHub repository.
584        ///
585        /// # Arguments
586        /// * `user` - The GitHub username
587        /// * `repo` - The repository name
588        /// * `path` - The path to the dataset file within the repository
589        ///
590        /// # Returns
591        /// A `Dataset` containing the loaded data
592        #[cfg(feature = "download")]
593        pub async fn load_from_repo(&self, user: &str, repo: &str, path: &str) -> Result<Dataset> {
594            let url = format!("https://raw.githubusercontent.com/{user}/{repo}/main/{path}");
595            self.client.download_dataset(&url, None).await
596        }
597
598        #[cfg(not(feature = "download"))]
599        /// Load a dataset from GitHub repository synchronously
600        pub fn load_from_repo_sync(&self, user: &str, repo: &str, path: &str) -> Result<Dataset> {
601            let url = format!("https://raw.githubusercontent.com/{user}/{repo}/main/{path}");
602            self.client.download_dataset_sync(&url, None)
603        }
604    }
605}
606
607/// Convenience functions for common external data operations
608pub mod convenience {
609    use super::repositories::*;
610    use super::*;
611
612    /// Load a dataset from a URL with progress tracking
613    #[cfg(feature = "download")]
614    pub async fn load_from_url(url: &str, config: Option<ExternalConfig>) -> Result<Dataset> {
615        let client = match config {
616            Some(cfg) => ExternalClient::with_config(cfg)?,
617            None => ExternalClient::new()?,
618        };
619
620        client
621            .download_dataset(
622                url,
623                Some(Box::new(|downloaded, total| {
624                    if let Some(percent) = (downloaded * 100).checked_div(total) {
625                        eprintln!("Downloaded: {percent:.1}% ({downloaded}/{total})");
626                    } else {
627                        eprintln!("Downloaded: {downloaded} bytes");
628                    }
629                })),
630            )
631            .await
632    }
633
634    /// Load a dataset from a URL synchronously
635    pub fn load_from_url_sync(url: &str, config: Option<ExternalConfig>) -> Result<Dataset> {
636        let client = match config {
637            Some(cfg) => ExternalClient::with_config(cfg)?,
638            None => ExternalClient::new()?,
639        };
640
641        client.download_dataset_sync(
642            url,
643            Some(Box::new(|downloaded, total| {
644                if let Some(percent) = (downloaded * 100).checked_div(total) {
645                    eprintln!("Downloaded: {percent:.1}% ({downloaded}/{total})");
646                } else {
647                    eprintln!("Downloaded: {downloaded} bytes");
648                }
649            })),
650        )
651    }
652
653    /// Load a UCI dataset by name
654    #[cfg(feature = "download")]
655    pub async fn load_uci_dataset(name: &str) -> Result<Dataset> {
656        let repo = UCIRepository::new()?;
657        repo.load_dataset(name).await
658    }
659
660    /// Load a UCI dataset by name synchronously
661    #[cfg(not(feature = "download"))]
662    pub fn load_uci_dataset_sync(name: &str) -> Result<Dataset> {
663        let repo = UCIRepository::new()?;
664        repo.load_dataset_sync(name)
665    }
666
667    /// Load a dataset from GitHub repository
668    #[cfg(feature = "download")]
669    pub async fn load_github_dataset(user: &str, repo: &str, path: &str) -> Result<Dataset> {
670        let github = GitHubRepository::new()?;
671        github.load_from_repo(user, repo, path).await
672    }
673
674    /// Load a dataset from GitHub repository synchronously
675    #[cfg(not(feature = "download"))]
676    pub fn load_github_dataset_sync(user: &str, repo: &str, path: &str) -> Result<Dataset> {
677        let github = GitHubRepository::new()?;
678        github.load_from_repo_sync(user, repo, path)
679    }
680
681    /// List available UCI datasets
682    pub fn list_uci_datasets() -> Result<Vec<&'static str>> {
683        let repo = UCIRepository::new()?;
684        Ok(repo.list_datasets())
685    }
686}
687
#[cfg(test)]
mod tests {
    use super::convenience::*;
    use super::*;

    // Defaults must match the values documented on ExternalConfig::default.
    #[test]
    fn test_external_config_default() {
        let config = ExternalConfig::default();
        assert_eq!(config.timeout_seconds, 300);
        assert_eq!(config.max_retries, 3);
        assert!(config.verify_ssl);
        assert!(config.use_cache);
    }

    // The UCI catalog is a static list; no network access is required.
    #[test]
    fn test_uci_repository_list_datasets() {
        let datasets = list_uci_datasets().expect("Operation failed");
        assert!(!datasets.is_empty());
        assert!(datasets.contains(&"wine"));
        assert!(datasets.contains(&"adult"));
    }

    // Three numeric rows with three columns: the last column becomes the
    // target, leaving two features per sample.
    #[test]
    fn test_parse_arff_data() {
        let arff_content = r#"
@relation test
@attribute feature1 numeric
@attribute feature2 numeric
@attribute class {0,1}
@data
1.0,2.0,0
3.0,4.0,1
5.0,6.0,0
"#;

        let client = ExternalClient::new().expect("Operation failed");
        let dataset = client
            .parse_arff_data(arff_content.as_bytes())
            .expect("Operation failed");

        assert_eq!(dataset.n_samples(), 3);
        assert_eq!(dataset.n_features(), 2);
        assert!(dataset.target.is_some());
    }

    // Network-dependent smoke test; failures are tolerated because CI
    // environments may block egress.
    #[tokio::test]
    #[cfg(feature = "download")]
    async fn test_download_small_csv() {
        // Test with a small public CSV dataset
        let url = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv";

        let result = load_from_url(url, None).await;
        match result {
            Ok(dataset) => {
                assert!(dataset.n_samples() > 0);
                assert!(dataset.n_features() > 0);
            }
            Err(e) => {
                // Network tests may fail in CI, so we just log the error
                eprintln!("Network test failed (expected in CI): {}", e);
            }
        }
    }
}