use crate::{Dataset, DatasetUtilsExt};
use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
use std::marker::PhantomData;
use tenflowers_core::{Result, Tensor, TensorError};
/// Configuration shared by the synthetic modern-ML dataset generators in this module.
#[derive(Debug, Clone)]
pub struct ModernMLConfig {
    /// Seed for the deterministic `StdRng` used by every generator.
    pub seed: u64,
    /// Dimensionality of every generated feature vector.
    pub feature_dim: usize,
    /// Magnitude of the uniform noise/augmentation applied when perturbing vectors.
    pub noise_level: f32,
}
impl Default for ModernMLConfig {
fn default() -> Self {
Self {
seed: 42,
feature_dim: 128,
noise_level: 0.1,
}
}
}
/// Dataset of pre-generated few-shot learning episodes, consumed sequentially
/// via `next_episode` and rewound with `reset`.
pub struct FewShotDataset<T> {
    /// All episodes, generated up front in `new`.
    episodes: Vec<Episode<T>>,
    /// Index of the next episode `next_episode` will yield.
    current_episode: usize,
    // NOTE(review): PhantomData is redundant here — `episodes` already owns T.
    _phantom: PhantomData<T>,
}
/// One few-shot episode: an `n_way`-class task with `k_shot` labeled support
/// examples per class plus a labeled query set drawn from the same classes.
#[derive(Debug, Clone)]
pub struct Episode<T> {
    /// (features, class id) pairs used for adaptation; `n_way * k_shot` entries.
    pub support_set: Vec<(Tensor<T>, usize)>,
    /// (features, class id) pairs used for evaluation.
    pub query_set: Vec<(Tensor<T>, usize)>,
    /// Number of classes in the episode.
    pub n_way: usize,
    /// Support examples per class.
    pub k_shot: usize,
}
impl<T> FewShotDataset<T>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + 'static,
{
    /// Generates `num_episodes` synthetic few-shot episodes.
    ///
    /// Each episode draws `n_way` random class prototypes, then produces
    /// `k_shot` noisy support examples and `query_per_class` noisy query
    /// examples per class (noise magnitude `config.noise_level`).
    ///
    /// # Errors
    /// Propagates tensor-construction failures from the noise step.
    pub fn new(
        num_episodes: usize,
        n_way: usize,
        k_shot: usize,
        query_per_class: usize,
        config: &ModernMLConfig,
    ) -> Result<Self> {
        let mut rng = StdRng::seed_from_u64(config.seed);
        let mut episodes = Vec::with_capacity(num_episodes);
        // Fixed: the loop variable was an unused `episode_idx`.
        for _ in 0..num_episodes {
            // One random prototype per class; every sample is a noisy copy of it.
            let mut class_prototypes = Vec::with_capacity(n_way);
            for _ in 0..n_way {
                class_prototypes.push(generate_random_vector(&mut rng, config.feature_dim, T::one()));
            }
            // Support set: k_shot noisy samples per class.
            // (Removed redundant `.take(n_way)` — the vec already has n_way entries.)
            let mut support_set = Vec::with_capacity(n_way * k_shot);
            for (class_id, prototype) in class_prototypes.iter().enumerate() {
                for _ in 0..k_shot {
                    let features = add_noise_to_vector(prototype, &mut rng, config.noise_level)?;
                    support_set.push((features, class_id));
                }
            }
            // Query set: query_per_class noisy samples per class.
            let mut query_set = Vec::with_capacity(n_way * query_per_class);
            for (class_id, prototype) in class_prototypes.iter().enumerate() {
                for _ in 0..query_per_class {
                    let features = add_noise_to_vector(prototype, &mut rng, config.noise_level)?;
                    query_set.push((features, class_id));
                }
            }
            episodes.push(Episode {
                support_set,
                query_set,
                n_way,
                k_shot,
            });
        }
        Ok(Self {
            episodes,
            current_episode: 0,
            _phantom: PhantomData,
        })
    }

    /// Returns the next episode and advances the cursor, or `None` when exhausted.
    pub fn next_episode(&mut self) -> Option<&Episode<T>> {
        if self.current_episode < self.episodes.len() {
            let episode = &self.episodes[self.current_episode];
            self.current_episode += 1;
            Some(episode)
        } else {
            None
        }
    }

    /// Rewinds the episode cursor to the beginning.
    pub fn reset(&mut self) {
        self.current_episode = 0;
    }

    /// Total number of generated episodes.
    pub fn num_episodes(&self) -> usize {
        self.episodes.len()
    }
}
/// Synthetic contrastive-learning dataset of anchor/positive and
/// anchor/negative tensor pairs.
pub struct ContrastiveLearningDataset<T> {
    /// (anchor, noisy copy of anchor) pairs — semantically "similar".
    positive_pairs: Vec<(Tensor<T>, Tensor<T>)>,
    /// (anchor, independent random vector) pairs — semantically "dissimilar".
    negative_pairs: Vec<(Tensor<T>, Tensor<T>)>,
    // NOTE(review): PhantomData is redundant — the pair vectors already own T.
    _phantom: PhantomData<T>,
}
impl<T> ContrastiveLearningDataset<T>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + 'static,
{
    /// Builds `num_positive_pairs` (anchor, perturbed-anchor) pairs followed by
    /// `num_negative_pairs` (anchor, unrelated-random) pairs, seeded from
    /// `config.seed`.
    ///
    /// # Errors
    /// Propagates tensor-construction failures from the perturbation step.
    pub fn new(
        num_positive_pairs: usize,
        num_negative_pairs: usize,
        config: &ModernMLConfig,
    ) -> Result<Self> {
        let mut rng = StdRng::seed_from_u64(config.seed);
        // Positive pair: the second element is a noisy view of the first.
        let positive_pairs = (0..num_positive_pairs)
            .map(|_| {
                let anchor = generate_random_vector(&mut rng, config.feature_dim, T::one());
                let positive = add_noise_to_vector(&anchor, &mut rng, config.noise_level)?;
                Ok((anchor, positive))
            })
            .collect::<Result<Vec<_>>>()?;
        // Negative pair: both elements are drawn independently.
        let negative_pairs = (0..num_negative_pairs)
            .map(|_| {
                let anchor = generate_random_vector(&mut rng, config.feature_dim, T::one());
                let negative = generate_random_vector(&mut rng, config.feature_dim, T::one());
                (anchor, negative)
            })
            .collect::<Vec<_>>();
        Ok(Self {
            positive_pairs,
            negative_pairs,
            _phantom: PhantomData,
        })
    }

    /// All positive (similar) pairs.
    pub fn positive_pairs(&self) -> &[(Tensor<T>, Tensor<T>)] {
        self.positive_pairs.as_slice()
    }

    /// All negative (dissimilar) pairs.
    pub fn negative_pairs(&self) -> &[(Tensor<T>, Tensor<T>)] {
        self.negative_pairs.as_slice()
    }

    /// Positive pair at `index`, if any.
    pub fn get_positive_pair(&self, index: usize) -> Option<&(Tensor<T>, Tensor<T>)> {
        self.positive_pairs.get(index)
    }

    /// Negative pair at `index`, if any.
    pub fn get_negative_pair(&self, index: usize) -> Option<&(Tensor<T>, Tensor<T>)> {
        self.negative_pairs.get(index)
    }
}
/// Self-supervised dataset pairing each original sample with a fixed number of
/// augmented views of it.
pub struct SelfSupervisedDataset<T> {
    /// The unmodified source samples.
    original_data: Vec<Tensor<T>>,
    /// `augmented_data[i]` holds the augmented views of `original_data[i]`.
    augmented_data: Vec<Vec<Tensor<T>>>,
    // NOTE(review): PhantomData is redundant — the vectors above already own T.
    _phantom: PhantomData<T>,
}
impl<T> SelfSupervisedDataset<T>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + 'static,
{
    /// Generates `num_samples` random vectors and, for each, `num_augmentations`
    /// scale-and-shift augmented views (strength `config.noise_level`).
    ///
    /// # Errors
    /// Propagates tensor-construction failures from the augmentation step.
    pub fn new(
        num_samples: usize,
        num_augmentations: usize,
        config: &ModernMLConfig,
    ) -> Result<Self> {
        let mut rng = StdRng::seed_from_u64(config.seed);
        let mut original_data = Vec::with_capacity(num_samples);
        let mut augmented_data = Vec::with_capacity(num_samples);
        for _ in 0..num_samples {
            let original = generate_random_vector(&mut rng, config.feature_dim, T::one());
            let mut augmentations = Vec::with_capacity(num_augmentations);
            for _ in 0..num_augmentations {
                augmentations.push(apply_augmentation(&original, &mut rng, config.noise_level)?);
            }
            augmented_data.push(augmentations);
            // Fixed: move the original instead of cloning it — all augmented
            // views have already been derived from it by this point.
            original_data.push(original);
        }
        Ok(Self {
            original_data,
            augmented_data,
            _phantom: PhantomData,
        })
    }

    /// The unmodified sample at `index`, if any.
    pub fn get_original(&self, index: usize) -> Option<&Tensor<T>> {
        self.original_data.get(index)
    }

    /// All augmented views of sample `index`, if any.
    pub fn get_augmentations(&self, index: usize) -> Option<&[Tensor<T>]> {
        self.augmented_data.get(index).map(|augs| augs.as_slice())
    }

    /// The `aug_index`-th augmented view of sample `sample_index`, if any.
    pub fn get_augmentation(&self, sample_index: usize, aug_index: usize) -> Option<&Tensor<T>> {
        self.augmented_data
            .get(sample_index)
            .and_then(|augs| augs.get(aug_index))
    }

    /// Number of original samples.
    pub fn len(&self) -> usize {
        self.original_data.len()
    }

    /// True when the dataset holds no samples.
    pub fn is_empty(&self) -> bool {
        self.original_data.is_empty()
    }
}
/// Collection of synthetic linear-regression tasks for meta-learning.
pub struct MetaLearningDataset<T> {
    /// One `TaskDataset` per generated task.
    tasks: Vec<TaskDataset<T>>,
    // NOTE(review): PhantomData is redundant — `tasks` already owns T.
    _phantom: PhantomData<T>,
}
/// Train/test split for a single synthetic task; each entry is a
/// (features, label) tensor pair.
#[derive(Debug, Clone)]
pub struct TaskDataset<T> {
    /// Training pairs for this task.
    pub train_data: Vec<(Tensor<T>, Tensor<T>)>,
    /// Held-out evaluation pairs for this task.
    pub test_data: Vec<(Tensor<T>, Tensor<T>)>,
    /// Zero-based task identifier, matching its position in the dataset.
    pub task_id: usize,
}
impl<T> MetaLearningDataset<T>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + 'static,
{
    /// Generates `num_tasks` synthetic linear-regression tasks.
    ///
    /// Each task draws its own random weight vector and bias, labels
    /// `samples_per_task` random feature vectors via `features · w + b`, and
    /// splits them into train/test according to `test_ratio`.
    ///
    /// # Errors
    /// Propagates tensor access/construction failures from label computation.
    pub fn new(
        num_tasks: usize,
        samples_per_task: usize,
        test_ratio: f64,
        config: &ModernMLConfig,
    ) -> Result<Self> {
        let mut rng = StdRng::seed_from_u64(config.seed);
        let mut tasks = Vec::with_capacity(num_tasks);
        // Fixed: clamp so a test_ratio > 1.0 cannot make `num_test` exceed
        // `samples_per_task` and underflow `num_train` (panic in debug builds).
        // A negative ratio already saturates to 0 through the `as usize` cast.
        let num_test = ((samples_per_task as f64 * test_ratio) as usize).min(samples_per_task);
        let num_train = samples_per_task - num_test;
        for task_id in 0..num_tasks {
            let mut train_data = Vec::with_capacity(num_train);
            let mut test_data = Vec::with_capacity(num_test);
            // Per-task ground-truth linear model: random weights plus a bias in [-1, 1).
            let task_weight = generate_random_vector(&mut rng, config.feature_dim, T::one());
            let task_bias = rng.random::<f32>() * 2.0 - 1.0;
            for _ in 0..num_train {
                let features = generate_random_vector(&mut rng, config.feature_dim, T::one());
                let label = compute_synthetic_label(&features, &task_weight, task_bias)?;
                train_data.push((features, label));
            }
            for _ in 0..num_test {
                let features = generate_random_vector(&mut rng, config.feature_dim, T::one());
                let label = compute_synthetic_label(&features, &task_weight, task_bias)?;
                test_data.push((features, label));
            }
            tasks.push(TaskDataset {
                train_data,
                test_data,
                task_id,
            });
        }
        Ok(Self {
            tasks,
            _phantom: PhantomData,
        })
    }

    /// All generated tasks.
    pub fn tasks(&self) -> &[TaskDataset<T>] {
        &self.tasks
    }

    /// Task at `index`, if any.
    pub fn get_task(&self, index: usize) -> Option<&TaskDataset<T>> {
        self.tasks.get(index)
    }

    /// Number of generated tasks.
    pub fn num_tasks(&self) -> usize {
        self.tasks.len()
    }
}
/// Builds a 1-D tensor of `dim` values drawn uniformly from [-1, 1) and
/// multiplied by `scale`; values that fail numeric conversion fall back to
/// zero, and a failed tensor construction falls back to an all-zero tensor.
fn generate_random_vector<T, R>(rng: &mut R, dim: usize, scale: T) -> Tensor<T>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + 'static,
    R: Rng,
{
    let mut values = Vec::with_capacity(dim);
    for _ in 0..dim {
        // Map a uniform f32 in [0, 1) onto [-1, 1), then convert into T.
        let sample = rng.random::<f32>() * 2.0 - 1.0;
        values.push(T::from(sample).unwrap_or(T::zero()) * scale);
    }
    Tensor::from_vec(values, &[dim]).unwrap_or_else(|_| Tensor::zeros(&[dim]))
}
/// Returns a copy of `vector` (same shape) with independent uniform noise in
/// [-noise_level, noise_level) added to every element.
///
/// # Errors
/// Fails when the tensor's data is not accessible as a contiguous slice, or
/// when the output tensor cannot be constructed.
fn add_noise_to_vector<T, R>(vector: &Tensor<T>, rng: &mut R, noise_level: f32) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + 'static,
    R: Rng,
{
    let data = vector.as_slice().ok_or_else(|| {
        TensorError::invalid_operation_simple("Cannot access tensor data".to_string())
    })?;
    let mut noisy_data = Vec::with_capacity(data.len());
    for &val in data {
        // Per-element uniform noise in [-noise_level, noise_level).
        let noise = rng.random::<f32>() * noise_level * 2.0 - noise_level;
        noisy_data.push(val + T::from(noise).unwrap_or(T::zero()));
    }
    Tensor::from_vec(noisy_data, vector.shape().dims())
}
/// Applies one random affine augmentation `x -> scale * x + shift` to the
/// whole vector, where the scale jitter and shift are each drawn from
/// [-aug_strength/2, aug_strength/2).
///
/// # Errors
/// Fails when the tensor's data is not accessible as a contiguous slice, or
/// when the output tensor cannot be constructed.
fn apply_augmentation<T, R>(vector: &Tensor<T>, rng: &mut R, aug_strength: f32) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + 'static,
    R: Rng,
{
    let data = vector.as_slice().ok_or_else(|| {
        TensorError::invalid_operation_simple("Cannot access tensor data".to_string())
    })?;
    // One scale and one shift for the entire vector (not per element).
    let scale = 1.0 + (rng.random::<f32>() - 0.5) * aug_strength;
    let shift = (rng.random::<f32>() - 0.5) * aug_strength;
    let augmented_data: Vec<T> = data
        .iter()
        // Fixed: collapsed the redundant `let shifted = …; shifted` binding
        // (clippy::let_and_return) into a single expression.
        .map(|&val| val * T::from(scale).unwrap_or(T::one()) + T::from(shift).unwrap_or(T::zero()))
        .collect();
    Tensor::from_vec(augmented_data, vector.shape().dims())
}
/// Computes a scalar synthetic regression label `features · weights + bias`,
/// returned as a 0-D tensor. No non-linear activation is applied.
///
/// # Errors
/// Fails when either tensor's data is not accessible as a contiguous slice,
/// or when the two vectors have different lengths.
fn compute_synthetic_label<T>(
    features: &Tensor<T>,
    weights: &Tensor<T>,
    bias: f32,
) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::Float
        + Send
        + Sync
        + 'static,
{
    let feat_data = features.as_slice().ok_or_else(|| {
        TensorError::invalid_operation_simple("Cannot access features tensor data".to_string())
    })?;
    let weight_data = weights.as_slice().ok_or_else(|| {
        TensorError::invalid_operation_simple("Cannot access weights tensor data".to_string())
    })?;
    if feat_data.len() != weight_data.len() {
        return Err(TensorError::invalid_shape_simple(format!(
            "Feature and weight dimensions mismatch: {} vs {}",
            feat_data.len(),
            weight_data.len()
        )));
    }
    // Dot product accumulated in T; `fold` is used because `T: Float` does not
    // guarantee a `Sum` implementation.
    let dot_product: T = feat_data
        .iter()
        .zip(weight_data.iter())
        .map(|(&f, &w)| f * w)
        .fold(T::zero(), |acc, val| acc + val);
    // Fixed: removed the dead `let activated = linear_output;` rebinding.
    let linear_output = dot_product + T::from(bias).unwrap_or(T::zero());
    Tensor::from_vec(vec![linear_output], &[])
}
#[cfg(test)]
mod tests {
    use super::*;

    // 5 episodes, 3-way, 2-shot, 1 query per class.
    #[test]
    fn test_few_shot_dataset_creation() {
        let config = ModernMLConfig::default();
        let dataset = FewShotDataset::<f32>::new(5, 3, 2, 1, &config)
            .expect("test: operation should succeed");
        assert_eq!(dataset.num_episodes(), 5);
        if let Some(episode) = dataset.episodes.first() {
            assert_eq!(episode.n_way, 3);
            assert_eq!(episode.k_shot, 2);
            // 3 classes * 2 shots = 6 support samples.
            assert_eq!(episode.support_set.len(), 6);
            // 3 classes * 1 query = 3 query samples.
            assert_eq!(episode.query_set.len(), 3);
        }
    }

    #[test]
    fn test_contrastive_learning_dataset() {
        let config = ModernMLConfig::default();
        let dataset = ContrastiveLearningDataset::<f32>::new(10, 15, &config)
            .expect("test: operation should succeed");
        assert_eq!(dataset.positive_pairs().len(), 10);
        assert_eq!(dataset.negative_pairs().len(), 15);
    }

    // Every sample must carry exactly `num_augmentations` views.
    #[test]
    fn test_self_supervised_dataset() {
        let config = ModernMLConfig::default();
        let dataset = SelfSupervisedDataset::<f32>::new(5, 3, &config)
            .expect("test: operation should succeed");
        assert_eq!(dataset.len(), 5);
        assert!(!dataset.is_empty());
        for i in 0..dataset.len() {
            assert_eq!(
                dataset
                    .get_augmentations(i)
                    .expect("test: operation should succeed")
                    .len(),
                3
            );
        }
    }

    // 20 samples at test_ratio 0.2 => 16 train / 4 test per task.
    #[test]
    fn test_meta_learning_dataset() {
        let config = ModernMLConfig::default();
        let dataset = MetaLearningDataset::<f32>::new(3, 20, 0.2, &config)
            .expect("test: operation should succeed");
        assert_eq!(dataset.num_tasks(), 3);
        for task in dataset.tasks() {
            assert_eq!(task.train_data.len(), 16);
            assert_eq!(task.test_data.len(), 4);
        }
    }

    #[test]
    fn test_vector_generation() {
        use scirs2_core::random::{rngs::StdRng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(42);
        let vector = generate_random_vector(&mut rng, 10, 1.0f32);
        assert_eq!(vector.shape().dims(), &[10]);
    }

    #[test]
    fn test_noise_addition() {
        use scirs2_core::random::{rngs::StdRng, SeedableRng};
        let original = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3])
            .expect("test: tensor creation should succeed");
        let mut rng = StdRng::seed_from_u64(42);
        let noisy =
            add_noise_to_vector(&original, &mut rng, 0.1).expect("test: operation should succeed");
        assert_eq!(noisy.shape().dims(), &[3]);
        let orig_data = original.as_slice().expect("tensor should be contiguous");
        let noisy_data = noisy.as_slice().expect("tensor should be contiguous");
        for (orig, noise) in orig_data.iter().zip(noisy_data.iter()) {
            // noise_level 0.1 bounds each perturbation by 0.1; 0.2 is a loose bound.
            assert!((orig - noise).abs() <= 0.2);
        }
    }
}