use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Source of training data.
///
/// Implementors must be `Send + Sync`, so a provider can be shared across
/// threads.
pub trait TrainingDataProvider: Send + Sync {
    /// Produces one [`TrainingBatch`] for the requested `batch_size`.
    ///
    /// How samples are selected for the batch is left to the implementor.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementor cannot produce a batch.
    fn get_batch(&self, batch_size: usize) -> Result<TrainingBatch>;

    /// Produces the complete [`TrainingDataset`].
    ///
    /// # Errors
    ///
    /// Returns an error when the implementor cannot produce the dataset.
    fn get_full_dataset(&self) -> Result<TrainingDataset>;

    /// Returns descriptive information about the dataset as a [`DatasetInfo`].
    ///
    /// Unlike the other methods on this trait, this one is infallible.
    fn get_dataset_info(&self) -> DatasetInfo;
}
/// Source of held-out data: a validation set and a test set.
///
/// Implementors must be `Send + Sync`, so a provider can be shared across
/// threads.
pub trait ValidationDataProvider: Send + Sync {
    /// Produces the [`ValidationDataset`].
    ///
    /// # Errors
    ///
    /// Returns an error when the implementor cannot produce the validation set.
    fn get_validation_set(&self) -> Result<ValidationDataset>;

    /// Produces the [`TestDataset`].
    ///
    /// # Errors
    ///
    /// Returns an error when the implementor cannot produce the test set.
    fn get_test_set(&self) -> Result<TestDataset>;
}
/// Evaluator that scores embeddings on a downstream task.
///
/// Implementors must be `Send + Sync`, so an evaluator can be shared across
/// threads.
pub trait DownstreamTaskEvaluator: Send + Sync {
    /// Scores `embeddings` against `labels` and returns a single `f64`.
    ///
    /// NOTE(review): presumably `embeddings[i]` corresponds to `labels[i]`, and
    /// the meaning/range of the returned score depends on the implementor —
    /// confirm against concrete implementations.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementor cannot complete the evaluation.
    fn evaluate(&self, embeddings: &[Vec<f32>], labels: &[usize]) -> Result<f64>;

    /// Returns the task's name as a borrowed string.
    fn get_task_name(&self) -> &str;

    /// Returns a reference to the [`TaskEvaluationConfig`] held by this
    /// evaluator.
    fn get_evaluation_config(&self) -> &TaskEvaluationConfig;
}
/// A group of input vectors with their target vectors and free-form metadata.
#[derive(Debug, Clone)]
pub struct TrainingBatch {
    /// Input vectors.
    ///
    /// NOTE(review): presumably one inner vector per sample — confirm.
    pub inputs: Vec<Vec<f32>>,
    /// Target vectors.
    ///
    /// NOTE(review): presumably aligned index-for-index with `inputs` — confirm.
    pub targets: Vec<Vec<f32>>,
    /// Arbitrary JSON values keyed by name; the set of keys is not fixed by
    /// this type.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// A training dataset: a list of batches plus summary statistics.
#[derive(Debug, Clone)]
pub struct TrainingDataset {
    /// The dataset contents, stored as a sequence of [`TrainingBatch`] values.
    pub samples: Vec<TrainingBatch>,
    /// Summary statistics stored alongside the samples.
    pub statistics: DatasetStatistics,
}
/// A validation dataset: a list of batches plus named ground-truth values.
#[derive(Debug, Clone)]
pub struct ValidationDataset {
    /// The dataset contents, stored as a sequence of [`TrainingBatch`] values.
    pub samples: Vec<TrainingBatch>,
    /// Ground-truth value series keyed by name.
    ///
    /// NOTE(review): the meaning of the keys is not visible here — confirm
    /// against the code that populates this map.
    pub ground_truth: HashMap<String, Vec<f64>>,
}
/// A test dataset: a list of batches plus reference embeddings.
#[derive(Debug, Clone)]
pub struct TestDataset {
    /// The dataset contents, stored as a sequence of [`TrainingBatch`] values.
    pub samples: Vec<TrainingBatch>,
    /// Reference embedding vectors.
    ///
    /// NOTE(review): how these line up with `samples` is not visible here —
    /// confirm against producers/consumers.
    pub reference_embeddings: Vec<Vec<f32>>,
}
/// Descriptive information about a dataset.
#[derive(Debug, Clone)]
pub struct DatasetInfo {
    /// Number of samples in the dataset.
    pub num_samples: usize,
    /// Dimensionality of the inputs.
    pub input_dim: usize,
    /// Dimensionality of the outputs.
    pub output_dim: usize,
    /// Kind of data the dataset holds.
    pub data_type: DataType,
    /// Free-form domain label.
    pub domain: String,
}
/// Kind of data a dataset holds.
///
/// This is a fieldless enum, so it derives `Copy`, `Eq` and `Hash` in addition
/// to `Debug`, `Clone` and `PartialEq`. That lets values be passed by value
/// without `.clone()` and used as keys in `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// Text data.
    Text,
    /// Image data.
    Image,
    /// Audio data.
    Audio,
    /// Video data.
    Video,
    /// Tabular data.
    Tabular,
    /// Graph data.
    Graph,
    /// Time-series data.
    TimeSeries,
    /// Data combining more than one modality.
    MultiModal,
}
/// Summary statistics for a dataset.
///
/// NOTE(review): each vector presumably holds one entry per feature
/// dimension — confirm against the code that fills these in.
#[derive(Debug, Clone)]
pub struct DatasetStatistics {
    /// Mean values.
    pub mean: Vec<f64>,
    /// Standard deviation values.
    pub std: Vec<f64>,
    /// Minimum values.
    pub min: Vec<f64>,
    /// Maximum values.
    pub max: Vec<f64>,
    /// Skewness values.
    pub skewness: Vec<f64>,
    /// Kurtosis values.
    pub kurtosis: Vec<f64>,
}
/// Configuration for evaluating a downstream task.
#[derive(Debug, Clone)]
pub struct TaskEvaluationConfig {
    /// Kind of task being evaluated.
    pub task_type: TaskType,
    /// Metric names.
    ///
    /// NOTE(review): the accepted names are not defined in this file — confirm
    /// against the evaluator implementations.
    pub metrics: Vec<String>,
    /// Number of cross-validation folds.
    pub cv_folds: usize,
    /// Seed for random number generation.
    pub random_seed: u64,
}
/// Kind of downstream task.
///
/// This is a fieldless enum, so it derives `Copy`, `Eq` and `Hash` in addition
/// to `Debug`, `Clone` and `PartialEq`. That lets values be passed by value
/// without `.clone()` and used as keys in `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskType {
    /// Classification task.
    Classification,
    /// Regression task.
    Regression,
    /// Clustering task.
    Clustering,
    /// Similarity-search task.
    SimilaritySearch,
    /// Retrieval task.
    Retrieval,
    /// Recommendation task.
    Recommendation,
}
/// Upper bounds on the resources an evaluation may use.
///
/// This type only carries the limits; how they are enforced is not defined
/// here.
#[derive(Debug, Clone)]
pub struct EvaluationBudget {
    /// Maximum number of evaluations.
    pub max_evaluations: usize,
    /// Maximum time, in minutes.
    pub max_time_minutes: usize,
    /// Maximum GPU time, in hours.
    pub max_gpu_hours: f64,
    /// Maximum CPU time, in hours.
    pub max_cpu_hours: f64,
    /// Maximum memory, in gigabytes.
    pub max_memory_gb: f64,
}
/// Settings for cross-validation.
#[derive(Debug, Clone)]
pub struct CrossValidationConfig {
    /// Number of folds.
    pub n_folds: usize,
    /// Whether stratified splitting is requested.
    pub stratified: bool,
    /// Seed for random number generation.
    pub random_seed: u64,
}