1use std::collections::HashMap;
13
14use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::random::prelude::*;
16use scirs2_core::random::{Distribution, Uniform};
17use serde::{Deserialize, Serialize};
18
19use crate::cache::DatasetCache;
20use crate::error::{DatasetsError, Result};
21use crate::external::ExternalClient;
22use crate::utils::Dataset;
23
/// Configuration for accessing a domain-specific dataset source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainConfig {
    /// Base URL of the remote dataset service, if any.
    pub base_url: Option<String>,
    /// API key used to authenticate requests, if the service requires one.
    pub api_key: Option<String>,
    /// File formats to prefer when a dataset is offered in several (e.g. "csv", "fits").
    pub preferred_formats: Vec<String>,
    /// Quality thresholds applied when filtering candidate datasets.
    pub quality_filters: QualityFilters,
}
36
37impl Default for DomainConfig {
38 fn default() -> Self {
39 Self {
40 base_url: None,
41 api_key: None,
42 preferred_formats: vec!["csv".to_string(), "fits".to_string(), "hdf5".to_string()],
43 quality_filters: QualityFilters::default(),
44 }
45 }
46}
47
/// Thresholds for filtering datasets by quality.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityFilters {
    /// Minimum number of samples a dataset must contain.
    pub min_samples: Option<usize>,
    /// Maximum tolerated missing-data value (default 0.1 — despite the name,
    /// presumably a fraction rather than a percentage; verify against callers).
    pub max_missing_percent: Option<f64>,
    /// Minimum required completeness fraction (0.0–1.0).
    pub min_completeness: Option<f64>,
    /// Earliest acceptable year for the data.
    pub min_year: Option<u32>,
}
60
61impl Default for QualityFilters {
62 fn default() -> Self {
63 Self {
64 min_samples: Some(100),
65 max_missing_percent: Some(0.1),
66 min_completeness: Some(0.9),
67 min_year: Some(2000),
68 }
69 }
70}
71
72pub mod astronomy {
74 use super::*;
75
    /// Accessor for astronomy datasets (synthetic stand-ins for catalogs
    /// such as Hipparcos and Gaia DR3).
    pub struct StellarDatasets {
        /// HTTP client for remote downloads; currently unused by the
        /// synthetic loaders (hence the dead_code allow).
        #[allow(dead_code)]
        client: ExternalClient,
        /// On-disk dataset cache; currently unused by the synthetic loaders.
        #[allow(dead_code)]
        cache: DatasetCache,
    }
83
84 impl StellarDatasets {
85 pub fn new() -> Result<Self> {
87 let cachedir = dirs::cache_dir()
88 .ok_or_else(|| {
89 DatasetsError::Other("Could not determine cache directory".to_string())
90 })?
91 .join("scirs2-datasets");
92 Ok(Self {
93 client: ExternalClient::new()?,
94 cache: DatasetCache::new(cachedir),
95 })
96 }
97
        /// Loads a synthetic stand-in for the Hipparcos catalog
        /// (118,218 stars, matching the size of the real catalog).
        pub fn load_hipparcos_catalog(&self) -> Result<Dataset> {
            self.load_synthetic_stellar_data("hipparcos", 118218)
        }
102
        /// Loads a synthetic 50,000-star sample styled after Gaia DR3.
        pub fn load_gaia_dr3_sample(&self) -> Result<Dataset> {
            self.load_synthetic_stellar_data("gaia_dr3", 50000)
        }
107
        /// Loads a synthetic exoplanet catalog of 5,000 planets.
        pub fn load_exoplanet_catalog(&self) -> Result<Dataset> {
            self.load_synthetic_exoplanet_data(5000)
        }
112
        /// Loads synthetic photometry for 1,000 supernova events.
        pub fn load_supernova_photometry(&self) -> Result<Dataset> {
            self.load_synthetic_supernova_data(1000)
        }
117
118 fn load_synthetic_stellar_data(&self, catalog: &str, nstars: usize) -> Result<Dataset> {
119 use scirs2_core::random::{Distribution, Normal};
120
121 let mut rng = thread_rng();
122
123 let mut data = Vec::with_capacity(nstars * 8);
125 let mut spectral_classes = Vec::with_capacity(nstars);
126
127 let ra_dist = scirs2_core::random::Uniform::new(0.0, 360.0).expect("Operation failed");
129 let dec_dist =
130 scirs2_core::random::Uniform::new(-90.0, 90.0).expect("Operation failed");
131 let magnitude_dist = Normal::new(8.0, 3.0).expect("Operation failed");
132 let color_dist = Normal::new(0.5, 0.3).expect("Operation failed");
133 let parallax_dist = Normal::new(10.0, 5.0).expect("Operation failed");
134 let proper_motion_dist = Normal::new(0.0, 50.0).expect("Operation failed");
135 let radial_velocity_dist = Normal::new(0.0, 30.0).expect("Operation failed");
136
137 for _ in 0..nstars {
138 data.push(ra_dist.sample(&mut rng));
140 data.push(dec_dist.sample(&mut rng));
142 data.push(magnitude_dist.sample(&mut rng));
144 data.push(color_dist.sample(&mut rng));
146 data.push((parallax_dist.sample(&mut rng) as f64).max(0.1f64));
148 data.push(proper_motion_dist.sample(&mut rng));
150 data.push(proper_motion_dist.sample(&mut rng));
152 data.push(radial_velocity_dist.sample(&mut rng));
154
155 let color = data[data.len() - 5];
157 let spectral_class = match color {
158 c if c < -0.3 => 0, c if c < -0.1 => 1, c if c < 0.2 => 2, c if c < 0.5 => 3, c if c < 0.8 => 4, c if c < 1.2 => 5, _ => 6, };
166 spectral_classes.push(spectral_class as f64);
167 }
168
169 let data_array = Array2::from_shape_vec((nstars, 8), data)
170 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
171
172 let target = Array1::from_vec(spectral_classes);
173
174 Ok(Dataset {
175 data: data_array,
176 target: Some(target),
177 featurenames: Some(vec![
178 "ra".to_string(),
179 "dec".to_string(),
180 "magnitude".to_string(),
181 "color_bv".to_string(),
182 "parallax".to_string(),
183 "pm_ra".to_string(),
184 "pm_dec".to_string(),
185 "radial_velocity".to_string(),
186 ]),
187 targetnames: Some(vec![
188 "O".to_string(),
189 "B".to_string(),
190 "A".to_string(),
191 "F".to_string(),
192 "G".to_string(),
193 "K".to_string(),
194 "M".to_string(),
195 ]),
196 feature_descriptions: Some(vec![
197 "Right Ascension (degrees)".to_string(),
198 "Declination (degrees)".to_string(),
199 "Apparent magnitude (visual)".to_string(),
200 "B-V color index".to_string(),
201 "Parallax (arcseconds)".to_string(),
202 "Proper motion RA (mas/year)".to_string(),
203 "Proper motion Dec (mas/year)".to_string(),
204 "Radial velocity (km/s)".to_string(),
205 ]),
206 description: Some(format!(
207 "Synthetic {catalog} stellar catalog with {nstars} _stars"
208 )),
209 metadata: std::collections::HashMap::new(),
210 })
211 }
212
213 fn load_synthetic_exoplanet_data(&self, nplanets: usize) -> Result<Dataset> {
214 use scirs2_core::random::{Distribution, LogNormal, Normal};
215
216 let mut rng = thread_rng();
217
218 let mut data = Vec::with_capacity(nplanets * 6);
220 let mut planet_types = Vec::with_capacity(nplanets);
221
222 let period_dist = LogNormal::new(1.0, 1.5).expect("Operation failed");
224 let radius_dist = LogNormal::new(0.0, 0.8).expect("Operation failed");
225 let mass_dist = LogNormal::new(1.0, 1.2).expect("Operation failed");
226 let stellar_mass_dist = Normal::new(1.0, 0.3).expect("Operation failed");
227 let stellar_temp_dist = Normal::new(5800.0, 1000.0).expect("Operation failed");
228 let metallicity_dist = Normal::new(0.0, 0.3).expect("Operation failed");
229
230 for _ in 0..nplanets {
231 data.push(period_dist.sample(&mut rng));
233 data.push(radius_dist.sample(&mut rng));
235 data.push(mass_dist.sample(&mut rng));
237 data.push((stellar_mass_dist.sample(&mut rng) as f64).max(0.1f64));
239 data.push(stellar_temp_dist.sample(&mut rng));
241 data.push(metallicity_dist.sample(&mut rng));
243
244 let radius = data[data.len() - 5];
246 let planet_type = match radius {
247 r if r < 1.25 => 0, r if r < 2.0 => 1, r if r < 4.0 => 2, r if r < 11.0 => 3, _ => 4, };
253 planet_types.push(planet_type as f64);
254 }
255
256 let data_array = Array2::from_shape_vec((nplanets, 6), data)
257 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
258
259 let target = Array1::from_vec(planet_types);
260
261 Ok(Dataset {
262 data: data_array,
263 target: Some(target),
264 featurenames: Some(vec![
265 "period".to_string(),
266 "radius".to_string(),
267 "mass".to_string(),
268 "stellar_mass".to_string(),
269 "stellar_temp".to_string(),
270 "metallicity".to_string(),
271 ]),
272 targetnames: Some(vec![
273 "Rocky".to_string(),
274 "Super-Earth".to_string(),
275 "Sub-Neptune".to_string(),
276 "Neptune".to_string(),
277 "Jupiter".to_string(),
278 ]),
279 feature_descriptions: Some(vec![
280 "Orbital period (days)".to_string(),
281 "Planet radius (Earth radii)".to_string(),
282 "Planet mass (Earth masses)".to_string(),
283 "Stellar mass (Solar masses)".to_string(),
284 "Stellar temperature (K)".to_string(),
285 "Stellar metallicity [Fe/H]".to_string(),
286 ]),
287 description: Some(format!(
288 "Synthetic exoplanet catalog with {nplanets} _planets"
289 )),
290 metadata: std::collections::HashMap::new(),
291 })
292 }
293
294 fn load_synthetic_supernova_data(&self, nsupernovae: usize) -> Result<Dataset> {
295 use scirs2_core::random::{Distribution, Normal};
296
297 let mut rng = thread_rng();
298
299 let mut data = Vec::with_capacity(nsupernovae * 10);
301 let mut sn_types = Vec::with_capacity(nsupernovae);
302
303 let _type_probs = [0.7, 0.15, 0.10, 0.05]; for _ in 0..nsupernovae {
307 let sn_type = rng.sample(Uniform::new(0, 4).expect("Operation failed"));
308
309 let (peak_mag, decline_rate, color_evolution, host_mass) = match sn_type {
310 0 => (-19.3, 1.1, 0.2, 10.5), 1 => (-18.5, 1.8, 0.5, 9.8), 2 => (-16.8, 0.8, 0.3, 9.2), _ => (-17.5, 1.2, 0.4, 9.0), };
315
316 let peak_noise = Normal::new(0.0, 0.3).expect("Operation failed");
318 let decline_noise = Normal::new(0.0, 0.2).expect("Operation failed");
319 let color_noise = Normal::new(0.0, 0.1).expect("Operation failed");
320 let host_noise = Normal::new(0.0, 0.5).expect("Operation failed");
321
322 data.push(peak_mag + peak_noise.sample(&mut rng));
324 data.push(decline_rate + decline_noise.sample(&mut rng));
326 data.push(color_evolution + color_noise.sample(&mut rng));
328 data.push(host_mass + host_noise.sample(&mut rng));
330 data.push(rng.random_range(0.01..0.3));
332 data.push(rng.random_range(20.0..200.0));
334 data.push(rng.random_range(0.7..1.3));
336 data.push(rng.random_range(0.0..0.5));
338 data.push(rng.random_range(15.0..22.0));
340 data.push(rng.random_range(-90.0..90.0));
342
343 sn_types.push(sn_type as f64);
344 }
345
346 let data_array = Array2::from_shape_vec((nsupernovae, 10), data)
347 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
348
349 let target = Array1::from_vec(sn_types);
350
351 Ok(Dataset {
352 data: data_array,
353 target: Some(target),
354 featurenames: Some(vec![
355 "peak_magnitude".to_string(),
356 "decline_rate".to_string(),
357 "color_max".to_string(),
358 "host_mass".to_string(),
359 "redshift".to_string(),
360 "duration".to_string(),
361 "stretch".to_string(),
362 "color_excess".to_string(),
363 "discovery_mag".to_string(),
364 "galactic_lat".to_string(),
365 ]),
366 targetnames: Some(vec![
367 "Type Ia".to_string(),
368 "Type Ib/c".to_string(),
369 "Type II-P".to_string(),
370 "Type II-L".to_string(),
371 ]),
372 feature_descriptions: Some(vec![
373 "Peak apparent magnitude".to_string(),
374 "Magnitude decline rate (mag/day)".to_string(),
375 "Maximum color index".to_string(),
376 "Host galaxy stellar mass (log10 M_sun)".to_string(),
377 "Cosmological redshift".to_string(),
378 "Light curve duration (days)".to_string(),
379 "Light curve stretch factor".to_string(),
380 "Host galaxy color excess E(B-V)".to_string(),
381 "Discovery magnitude".to_string(),
382 "Galactic latitude (degrees)".to_string(),
383 ]),
384 description: Some(format!(
385 "Synthetic supernova catalog with {nsupernovae} events"
386 )),
387 metadata: std::collections::HashMap::new(),
388 })
389 }
390 }
391}
392
393pub mod genomics {
395 use super::*;
396
    /// Accessor for genomics datasets (synthetic gene expression and DNA
    /// sequence data).
    pub struct GenomicsDatasets {
        /// HTTP client for remote downloads; currently unused by the
        /// synthetic loaders (hence the dead_code allow).
        #[allow(dead_code)]
        client: ExternalClient,
        /// On-disk dataset cache; currently unused by the synthetic loaders.
        #[allow(dead_code)]
        cache: DatasetCache,
    }
404
405 impl GenomicsDatasets {
406 pub fn new() -> Result<Self> {
408 let cachedir = dirs::cache_dir()
409 .ok_or_else(|| {
410 DatasetsError::Other("Could not determine cache directory".to_string())
411 })?
412 .join("scirs2-datasets");
413 Ok(Self {
414 client: ExternalClient::new()?,
415 cache: DatasetCache::new(cachedir),
416 })
417 }
418
419 pub fn load_gene_expression(&self, n_samples: usize, ngenes: usize) -> Result<Dataset> {
421 use scirs2_core::random::{Distribution, LogNormal, Normal};
422
423 let mut rng = thread_rng();
424
425 let mut data = Vec::with_capacity(n_samples * ngenes);
427 let mut phenotypes = Vec::with_capacity(n_samples);
428
429 let condition_effects = [1.0, 2.5, 0.4, 1.8, 0.7]; for sample_idx in 0..n_samples {
433 let condition = sample_idx % condition_effects.len();
434 let base_effect = condition_effects[condition];
435
436 for gene_idx in 0..ngenes {
437 let base_expr = LogNormal::new(5.0, 2.0)
439 .expect("Operation failed")
440 .sample(&mut rng);
441
442 let gene_effect = if gene_idx < ngenes / 10 {
444 base_effect
446 } else {
447 1.0
448 };
449
450 let noise = Normal::new(1.0, 0.2)
452 .expect("Operation failed")
453 .sample(&mut rng);
454
455 let expression: f64 = base_expr * gene_effect * noise;
456 data.push(expression.ln()); }
458
459 phenotypes.push(condition as f64);
460 }
461
462 let data_array = Array2::from_shape_vec((n_samples, ngenes), data)
463 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
464
465 let target = Array1::from_vec(phenotypes);
466
467 let featurenames: Vec<String> = (0..ngenes).map(|i| format!("GENE_{i:06}")).collect();
469
470 Ok(Dataset {
471 data: data_array,
472 target: Some(target),
473 featurenames: Some(featurenames.clone()),
474 targetnames: Some(vec![
475 "Control".to_string(),
476 "Treatment_A".to_string(),
477 "Treatment_B".to_string(),
478 "Disease_X".to_string(),
479 "Disease_Y".to_string(),
480 ]),
481 feature_descriptions: Some(
482 featurenames
483 .iter()
484 .map(|name| format!("Expression level of {name}"))
485 .collect(),
486 ),
487 description: Some(format!(
488 "Synthetic gene expression data: {n_samples} _samples × {ngenes} _genes"
489 )),
490 metadata: std::collections::HashMap::new(),
491 })
492 }
493
494 pub fn load_dnasequences(
496 &self,
497 nsequences: usize,
498 sequence_length: usize,
499 ) -> Result<Dataset> {
500 let mut rng = thread_rng();
501 let nucleotides = ['A', 'T', 'G', 'C'];
502
503 let mut sequences = Vec::new();
504 let mut sequence_types = Vec::with_capacity(nsequences);
505
506 for seq_idx in 0..nsequences {
507 let mut sequence = String::with_capacity(sequence_length);
508
509 let seq_type = seq_idx % 3; for _pos in 0..sequence_length {
513 let nucleotide = match seq_type {
514 0 => {
515 if rng.random::<f64>() < 0.6 {
517 if rng.random::<f64>() < 0.5 {
518 'G'
519 } else {
520 'C'
521 }
522 } else if rng.random::<f64>() < 0.5 {
523 'A'
524 } else {
525 'T'
526 }
527 }
528 1 => {
529 if rng.random::<f64>() < 0.6 {
531 if rng.random::<f64>() < 0.5 {
532 'A'
533 } else {
534 'T'
535 }
536 } else if rng.random::<f64>() < 0.5 {
537 'G'
538 } else {
539 'C'
540 }
541 }
542 _ => {
543 nucleotides[rng.sample(Uniform::new(0, 4).expect("Operation failed"))]
545 }
546 };
547
548 sequence.push(nucleotide);
549 }
550
551 sequences.push(sequence);
552 sequence_types.push(seq_type as f64);
553 }
554
555 let mut data = Vec::new();
557 let k = 3;
558 let kmers = Self::generate_kmers(k);
559
560 for sequence in &sequences {
561 let kmer_counts = Self::count_kmers(sequence, k, &kmers);
562 data.extend(kmer_counts);
563 }
564
565 let n_features = 4_usize.pow(k as u32); let data_array = Array2::from_shape_vec((nsequences, n_features), data)
567 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
568
569 let target = Array1::from_vec(sequence_types);
570
571 Ok(Dataset {
572 data: data_array,
573 target: Some(target),
574 featurenames: Some(kmers.clone()),
575 targetnames: Some(vec![
576 "GC-rich".to_string(),
577 "AT-rich".to_string(),
578 "Random".to_string(),
579 ]),
580 feature_descriptions: Some(
581 kmers
582 .iter()
583 .map(|kmer| format!("Frequency of {k}-mer: {kmer}"))
584 .collect(),
585 ),
586 description: Some(format!(
587 "DNA sequences: {nsequences} seqs × {k}-mer features"
588 )),
589 metadata: std::collections::HashMap::new(),
590 })
591 }
592
593 fn generate_kmers(k: usize) -> Vec<String> {
594 let nucleotides = vec!['A', 'T', 'G', 'C'];
595 let mut kmers = Vec::new();
596
597 fn generate_recursive(
598 current: String,
599 remaining: usize,
600 nucleotides: &[char],
601 kmers: &mut Vec<String>,
602 ) {
603 if remaining == 0 {
604 kmers.push(current);
605 return;
606 }
607
608 for &nucleotide in nucleotides {
609 let mut new_current = current.clone();
610 new_current.push(nucleotide);
611 generate_recursive(new_current, remaining - 1, nucleotides, kmers);
612 }
613 }
614
615 generate_recursive(String::new(), k, &nucleotides, &mut kmers);
616 kmers
617 }
618
619 fn count_kmers(sequence: &str, k: usize, kmers: &[String]) -> Vec<f64> {
620 let mut counts = vec![0.0; kmers.len()];
621 let kmer_to_idx: HashMap<&str, usize> = kmers
622 .iter()
623 .enumerate()
624 .map(|(i, k)| (k.as_str(), i))
625 .collect();
626
627 for i in 0..=sequence.len().saturating_sub(k) {
628 let kmer = &sequence[i..i + k];
629 if let Some(&idx) = kmer_to_idx.get(kmer) {
630 counts[idx] += 1.0;
631 }
632 }
633
634 let total: f64 = counts.iter().sum();
636 if total > 0.0 {
637 for count in &mut counts {
638 *count /= total;
639 }
640 }
641
642 counts
643 }
644 }
645}
646
647pub mod climate {
649 use super::*;
650
    /// Accessor for climate datasets (synthetic temperature time series and
    /// atmospheric chemistry measurements).
    pub struct ClimateDatasets {
        /// HTTP client for remote downloads; currently unused by the
        /// synthetic loaders (hence the dead_code allow).
        #[allow(dead_code)]
        client: ExternalClient,
        /// On-disk dataset cache; currently unused by the synthetic loaders.
        #[allow(dead_code)]
        cache: DatasetCache,
    }
658
659 impl ClimateDatasets {
660 pub fn new() -> Result<Self> {
662 let cachedir = dirs::cache_dir()
663 .ok_or_else(|| {
664 DatasetsError::Other("Could not determine cache directory".to_string())
665 })?
666 .join("scirs2-datasets");
667 Ok(Self {
668 client: ExternalClient::new()?,
669 cache: DatasetCache::new(cachedir),
670 })
671 }
672
673 pub fn load_temperature_timeseries(
675 &self,
676 n_stations: usize,
677 n_years: usize,
678 ) -> Result<Dataset> {
679 use scirs2_core::random::{Distribution, Normal};
680
681 let mut rng = thread_rng();
682 let days_per_year = 365;
683 let total_days = n_years * days_per_year;
684
685 let mut data = Vec::with_capacity(n_stations * 8); let mut climate_zones = Vec::with_capacity(n_stations);
687
688 for station_idx in 0..n_stations {
689 let zone = station_idx % 5; climate_zones.push(zone as f64);
692
693 let (base_temp, temp_amplitude, annual_precip, humidity) = match zone {
695 0 => (25.0, 5.0, 2000.0, 80.0), 1 => (15.0, 15.0, 800.0, 60.0), 2 => (-5.0, 20.0, 400.0, 70.0), 3 => (5.0, 8.0, 200.0, 40.0), _ => (-10.0, 25.0, 300.0, 75.0), };
701
702 let mut temperatures = Vec::with_capacity(total_days);
704 let mut precipitation = Vec::with_capacity(total_days);
705
706 for day in 0..total_days {
707 let year_progress = (day % days_per_year) as f64 / days_per_year as f64;
708 let seasonal_temp = base_temp
709 + temp_amplitude * (year_progress * 2.0 * std::f64::consts::PI).cos();
710
711 let temp_noise = Normal::new(0.0, 2.0).expect("Operation failed");
713 let temp = seasonal_temp + temp_noise.sample(&mut rng);
714 temperatures.push(temp);
715
716 let seasonal_precip_factor = match zone {
718 0 => {
719 1.0 + 0.3
720 * (year_progress * 2.0 * std::f64::consts::PI
721 + std::f64::consts::PI)
722 .cos()
723 }
724 1 => 1.0 + 0.2 * (year_progress * 2.0 * std::f64::consts::PI).sin(),
725 _ => 1.0,
726 };
727
728 let precip = if rng.random::<f64>() < 0.3 {
729 rng.random_range(0.0..20.0) * seasonal_precip_factor
731 } else {
732 0.0
733 };
734 precipitation.push(precip);
735 }
736
737 let mean_temp = temperatures.iter().sum::<f64>() / temperatures.len() as f64;
739 let max_temp = temperatures
740 .iter()
741 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
742 let min_temp = temperatures.iter().fold(f64::INFINITY, |a, &b| a.min(b));
743 let temp_range = max_temp - min_temp;
744
745 let total_precip = precipitation.iter().sum::<f64>();
746 let precip_days = precipitation.iter().filter(|&&p| p > 0.0).count() as f64;
747
748 let avg_humidity = humidity
750 + Normal::new(0.0, 5.0)
751 .expect("Operation failed")
752 .sample(&mut rng);
753 let wind_speed = rng.random_range(2.0..15.0);
754
755 data.extend(vec![
756 mean_temp,
757 temp_range,
758 total_precip,
759 precip_days,
760 avg_humidity,
761 wind_speed,
762 base_temp, annual_precip / 365.0, ]);
765 }
766
767 let data_array = Array2::from_shape_vec((n_stations, 8), data)
768 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
769
770 let target = Array1::from_vec(climate_zones);
771
772 Ok(Dataset {
773 data: data_array,
774 target: Some(target),
775 featurenames: Some(vec![
776 "mean_temperature".to_string(),
777 "temperature_range".to_string(),
778 "annual_precipitation".to_string(),
779 "precipitation_days".to_string(),
780 "avg_humidity".to_string(),
781 "avg_wind_speed".to_string(),
782 "latitude_proxy".to_string(),
783 "daily_precip_avg".to_string(),
784 ]),
785 targetnames: Some(vec![
786 "Tropical".to_string(),
787 "Temperate".to_string(),
788 "Continental".to_string(),
789 "Desert".to_string(),
790 "Arctic".to_string(),
791 ]),
792 feature_descriptions: Some(vec![
793 "Mean annual temperature (°C)".to_string(),
794 "Temperature range (max-min, °C)".to_string(),
795 "Total annual precipitation (mm)".to_string(),
796 "Number of precipitation days per year".to_string(),
797 "Average humidity (%)".to_string(),
798 "Average wind speed (m/s)".to_string(),
799 "Latitude proxy (normalized)".to_string(),
800 "Average daily precipitation (mm/day)".to_string(),
801 ]),
802 description: Some(format!(
803 "Climate data: {n_stations} _stations × {n_years} _years"
804 )),
805 metadata: std::collections::HashMap::new(),
806 })
807 }
808
809 pub fn load_atmospheric_chemistry(&self, nmeasurements: usize) -> Result<Dataset> {
811 use scirs2_core::random::{Distribution, LogNormal, Normal};
812
813 let mut rng = thread_rng();
814
815 let mut data = Vec::with_capacity(nmeasurements * 12);
816 let mut air_quality_index = Vec::with_capacity(nmeasurements);
817
818 for _ in 0..nmeasurements {
819 let base_pollution = rng.random_range(0.0..1.0);
821
822 let pm25: f64 = LogNormal::new(2.0 + base_pollution, 0.5)
824 .expect("Failed to create array")
825 .sample(&mut rng);
826 let pm10 = pm25 * rng.random_range(1.5..2.5);
827 let no2 = LogNormal::new(3.0 + base_pollution * 0.5, 0.3)
828 .expect("Failed to create array")
829 .sample(&mut rng);
830 let so2 = LogNormal::new(1.0 + base_pollution * 0.3, 0.4)
831 .expect("Failed to create array")
832 .sample(&mut rng);
833 let o3 = LogNormal::new(4.0 - base_pollution * 0.2, 0.2)
834 .expect("Failed to create array")
835 .sample(&mut rng);
836 let co = LogNormal::new(0.5 + base_pollution * 0.4, 0.3)
837 .expect("Failed to create array")
838 .sample(&mut rng);
839
840 let temperature = Normal::new(20.0, 10.0)
842 .expect("Operation failed")
843 .sample(&mut rng);
844 let humidity = rng.random_range(30.0..90.0);
845 let wind_speed = rng.random_range(0.5..12.0);
846 let pressure = Normal::new(1013.0, 15.0)
847 .expect("Operation failed")
848 .sample(&mut rng);
849
850 let visibility = (50.0 - pm25.ln() * 5.0).max(1.0);
852 let uv_index = rng.random_range(0.0..12.0);
853
854 data.extend(vec![
855 pm25,
856 pm10,
857 no2,
858 so2,
859 o3,
860 co,
861 temperature,
862 humidity,
863 wind_speed,
864 pressure,
865 visibility,
866 uv_index,
867 ]);
868
869 let aqi = Self::calculate_aqi(pm25, pm10, no2, so2, o3, co);
871 air_quality_index.push(aqi);
872 }
873
874 let data_array = Array2::from_shape_vec((nmeasurements, 12), data)
875 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
876
877 let target = Array1::from_vec(air_quality_index);
878
879 Ok(Dataset {
880 data: data_array,
881 target: Some(target),
882 featurenames: Some(vec![
883 "pm2_5".to_string(),
884 "pm10".to_string(),
885 "no2".to_string(),
886 "so2".to_string(),
887 "o3".to_string(),
888 "co".to_string(),
889 "temperature".to_string(),
890 "humidity".to_string(),
891 "wind_speed".to_string(),
892 "pressure".to_string(),
893 "visibility".to_string(),
894 "uv_index".to_string(),
895 ]),
896 targetnames: None,
897 feature_descriptions: Some(vec![
898 "PM2.5 concentration (µg/m³)".to_string(),
899 "PM10 concentration (µg/m³)".to_string(),
900 "NO2 concentration (µg/m³)".to_string(),
901 "SO2 concentration (µg/m³)".to_string(),
902 "O3 concentration (µg/m³)".to_string(),
903 "CO concentration (µg/m³)".to_string(),
904 "Temperature (°C)".to_string(),
905 "Relative humidity (%)".to_string(),
906 "Wind speed (m/s)".to_string(),
907 "Atmospheric pressure (hPa)".to_string(),
908 "Visibility (km)".to_string(),
909 "UV index".to_string(),
910 ]),
911 description: Some(format!(
912 "Atmospheric chemistry _measurements: {nmeasurements} samples"
913 )),
914 metadata: std::collections::HashMap::new(),
915 })
916 }
917
918 #[allow(clippy::too_many_arguments)]
919 fn calculate_aqi(pm25: f64, pm10: f64, no2: f64, so2: f64, o3: f64, co: f64) -> f64 {
920 let pm25_aqi = (pm25 / 35.0 * 100.0).min(300.0);
922 let pm10_aqi = (pm10 / 150.0 * 100.0).min(300.0);
923 let no2_aqi = (no2 / 100.0 * 100.0).min(300.0);
924 let so2_aqi = (so2 / 75.0 * 100.0).min(300.0);
925 let o3_aqi = (o3 / 120.0 * 100.0).min(300.0);
926 let co_aqi = (co / 9.0 * 100.0).min(300.0);
927
928 [pm25_aqi, pm10_aqi, no2_aqi, so2_aqi, o3_aqi, co_aqi]
930 .iter()
931 .fold(0.0f64, |a, &b| a.max(b))
932 }
933 }
934}
935
936pub mod convenience {
938 use super::astronomy::StellarDatasets;
939 use super::climate::ClimateDatasets;
940 use super::genomics::GenomicsDatasets;
941 use super::*;
942
    /// Loads the synthetic Hipparcos-style stellar classification dataset.
    pub fn load_stellar_classification() -> Result<Dataset> {
        let datasets = StellarDatasets::new()?;
        datasets.load_hipparcos_catalog()
    }
948
    /// Loads the synthetic exoplanet catalog.
    pub fn load_exoplanets() -> Result<Dataset> {
        let datasets = StellarDatasets::new()?;
        datasets.load_exoplanet_catalog()
    }
954
    /// Loads synthetic gene-expression data.
    ///
    /// Defaults: 200 samples × 1000 genes when the arguments are `None`.
    pub fn load_gene_expression(
        n_samples: Option<usize>,
        ngenes: Option<usize>,
    ) -> Result<Dataset> {
        let datasets = GenomicsDatasets::new()?;
        datasets.load_gene_expression(n_samples.unwrap_or(200), ngenes.unwrap_or(1000))
    }
963
964 pub fn load_climate_data(
966 _n_stations: Option<usize>,
967 n_years: Option<usize>,
968 ) -> Result<Dataset> {
969 let datasets = ClimateDatasets::new()?;
970 datasets.load_temperature_timeseries(_n_stations.unwrap_or(100), n_years.unwrap_or(10))
971 }
972
973 pub fn load_atmospheric_chemistry(_nmeasurements: Option<usize>) -> Result<Dataset> {
975 let datasets = ClimateDatasets::new()?;
976 datasets.load_atmospheric_chemistry(_nmeasurements.unwrap_or(1000))
977 }
978
979 pub fn list_domain_datasets() -> Vec<(&'static str, &'static str)> {
981 vec![
982 ("astronomy", "stellar_classification"),
983 ("astronomy", "exoplanets"),
984 ("astronomy", "supernovae"),
985 ("astronomy", "gaia_dr3"),
986 ("genomics", "gene_expression"),
987 ("genomics", "dnasequences"),
988 ("climate", "temperature_timeseries"),
989 ("climate", "atmospheric_chemistry"),
990 ]
991 }
992}
993
#[cfg(test)]
mod tests {
    use super::convenience::*;
    // Removed `use scirs2_core::random::Uniform;` — it was unused in this
    // module and triggered an unused-import warning.

    #[test]
    fn test_load_stellar_classification() {
        let dataset = load_stellar_classification().expect("Operation failed");
        // Hipparcos-sized catalog: well over 1000 stars, 8 features.
        assert!(dataset.n_samples() > 1000);
        assert_eq!(dataset.n_features(), 8);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_exoplanets() {
        let dataset = load_exoplanets().expect("Operation failed");
        assert!(dataset.n_samples() > 100);
        assert_eq!(dataset.n_features(), 6);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_gene_expression() {
        let dataset = load_gene_expression(Some(50), Some(100)).expect("Operation failed");
        assert_eq!(dataset.n_samples(), 50);
        assert_eq!(dataset.n_features(), 100);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_climate_data() {
        let dataset = load_climate_data(Some(20), Some(5)).expect("Operation failed");
        assert_eq!(dataset.n_samples(), 20);
        assert_eq!(dataset.n_features(), 8);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_atmospheric_chemistry() {
        let dataset = load_atmospheric_chemistry(Some(100)).expect("Operation failed");
        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 12);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_list_domain_datasets() {
        let datasets = list_domain_datasets();
        assert!(!datasets.is_empty());
        // Every domain should be represented in the catalog.
        assert!(datasets.iter().any(|(domain, _)| *domain == "astronomy"));
        assert!(datasets.iter().any(|(domain, _)| *domain == "genomics"));
        assert!(datasets.iter().any(|(domain, _)| *domain == "climate"));
    }
}
1047}