use crate::{Dataset, Result};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::{Rng, RngExt, SeedableRng};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tenflowers_core::{Tensor, TensorError};
/// Process-wide [`SeedManager`] shared behind an `Arc<Mutex<_>>`.
///
/// The cell is populated either explicitly through `SeedManager::set_global` or lazily
/// (with master seed 42) by the first call to `SeedManager::global`.
static GLOBAL_SEED_MANAGER: std::sync::OnceLock<Arc<Mutex<SeedManager>>> =
std::sync::OnceLock::new();
/// Derives reproducible seeds for named components and numbered operations from a single
/// master seed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SeedManager {
/// Root seed that is hashed into every derived component and operation seed.
master_seed: u64,
/// Cache of component seeds that have already been derived, keyed by component name.
component_seeds: HashMap<String, u64>,
/// Number of operation seeds handed out so far; incremented by `next_operation_seed`.
operation_counter: u64,
}
impl SeedManager {
    /// Creates a manager rooted at `master_seed`, with no cached component seeds and an
    /// operation counter of zero.
    pub fn new(master_seed: u64) -> Self {
        Self {
            master_seed,
            component_seeds: HashMap::new(),
            operation_counter: 0,
        }
    }

    /// Returns the master seed this manager was created with.
    pub fn master_seed(&self) -> u64 {
        self.master_seed
    }

    /// Returns the seed associated with `component`.
    ///
    /// The first request for a name hashes the master seed together with the name and caches
    /// the result; later requests for the same name return the cached value.
    pub fn get_component_seed(&mut self, component: &str) -> u64 {
        if let Some(&cached) = self.component_seeds.get(component) {
            return cached;
        }
        let seed = Self::derive_seed(self.master_seed, component);
        self.component_seeds.insert(component.to_string(), seed);
        seed
    }

    /// Advances the operation counter and returns a seed derived from the master seed and the
    /// new counter value, so consecutive calls hash different inputs.
    pub fn next_operation_seed(&mut self) -> u64 {
        self.operation_counter += 1;
        Self::derive_seed(self.master_seed, &self.operation_counter)
    }

    /// Creates an RNG seeded with the component seed for `component`.
    ///
    /// Two calls with the same component name return RNGs that start from the same seed.
    pub fn create_rng(&mut self, component: &str) -> StdRng {
        let seed = self.get_component_seed(component);
        StdRng::seed_from_u64(seed)
    }

    /// Installs `manager` as the process-wide seed manager.
    ///
    /// If the global cell has already been initialised (by an earlier `set_global` or by
    /// `global`), the existing manager's state is overwritten in place instead of the request
    /// being dropped, so every previously obtained handle observes the new state.
    pub fn set_global(manager: SeedManager) {
        if let Err(rejected) = GLOBAL_SEED_MANAGER.set(Arc::new(Mutex::new(manager))) {
            // `OnceLock::set` hands back the value it refused to store. That `Arc` is private
            // to this call, so locking it alongside the live one cannot deadlock.
            let shared = Self::global();
            let mut active = shared.lock().unwrap_or_else(|e| e.into_inner());
            let mut incoming = rejected.lock().unwrap_or_else(|e| e.into_inner());
            std::mem::swap(&mut *active, &mut *incoming);
        }
    }

    /// Returns a handle to the process-wide seed manager, creating one with master seed 42 if
    /// none has been installed yet.
    pub fn global() -> Arc<Mutex<SeedManager>> {
        GLOBAL_SEED_MANAGER
            .get_or_init(|| Arc::new(Mutex::new(SeedManager::new(42))))
            .clone()
    }

    /// Clears the cached component seeds and resets the operation counter to zero.
    ///
    /// The master seed is kept, so seeds derived after a reset repeat the earlier sequence.
    pub fn reset(&mut self) {
        self.component_seeds.clear();
        self.operation_counter = 0;
    }

    /// Hashes `master_seed` followed by `key` into a 64-bit seed.
    ///
    /// NOTE(review): the standard library does not guarantee that `DefaultHasher` produces
    /// the same values across Rust releases, so derived seeds may differ between toolchains.
    fn derive_seed<K>(master_seed: u64, key: &K) -> u64
    where
        K: std::hash::Hash + ?Sized,
    {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        master_seed.hash(&mut hasher);
        key.hash(&mut hasher);
        hasher.finish()
    }
}
/// Snapshot of the runtime environment and seed state, as produced by
/// `EnvironmentInfo::capture`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvironmentInfo {
/// Rust toolchain version; `capture` currently always stores the literal `"unknown"`.
pub rust_version: String,
/// Operating system name taken from `std::env::consts::OS`.
pub os: String,
/// CPU architecture taken from `std::env::consts::ARCH`.
pub arch: String,
/// CPU count as reported by `num_cpus::get()`.
pub num_cpus: usize,
/// Capture time in whole seconds since the Unix epoch.
pub timestamp: u64,
/// Values of `RUST_LOG`, `CARGO_TARGET_DIR` and `RUSTFLAGS` that could be read at capture
/// time; variables that could not be read are absent.
pub env_vars: HashMap<String, String>,
/// Master seed and cached component seeds of the seed manager passed to `capture`.
pub seed_info: SeedInfo,
}
/// Serializable copy of a seed manager's master seed and cached component seeds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SeedInfo {
/// Master seed of the captured seed manager.
pub master_seed: u64,
/// Component seeds that had been derived (and cached) at capture time, keyed by name.
pub component_seeds: HashMap<String, u64>,
}
impl EnvironmentInfo {
    /// Records the current environment together with the seed state of `seed_manager`.
    ///
    /// The timestamp is the number of whole seconds since the Unix epoch, falling back to zero
    /// if the system clock reports a time before the epoch. Only `RUST_LOG`,
    /// `CARGO_TARGET_DIR` and `RUSTFLAGS` are sampled from the environment, and only when they
    /// can be read. The Rust version is not detected and is stored as `"unknown"`.
    pub fn capture(seed_manager: &SeedManager) -> Self {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();

        let env_vars: HashMap<String, String> = ["RUST_LOG", "CARGO_TARGET_DIR", "RUSTFLAGS"]
            .iter()
            .filter_map(|&name| {
                std::env::var(name)
                    .ok()
                    .map(|value| (name.to_string(), value))
            })
            .collect();

        let seed_info = SeedInfo {
            master_seed: seed_manager.master_seed,
            component_seeds: seed_manager.component_seeds.clone(),
        };

        Self {
            rust_version: "unknown".to_string(),
            os: std::env::consts::OS.to_string(),
            arch: std::env::consts::ARCH.to_string(),
            num_cpus: num_cpus::get(),
            timestamp: since_epoch.as_secs(),
            env_vars,
            seed_info,
        }
    }
}
/// Wraps a dataset and exposes its samples through a fixed index mapping.
#[derive(Debug)]
pub struct DeterministicDataset<T, D> {
/// The wrapped dataset that actually stores the samples.
dataset: D,
/// Position `i` of this wrapper maps to sample `indices[i]` of the wrapped dataset.
indices: Vec<usize>,
/// Ties the element type `T` to the wrapper without storing a value of it.
_phantom: std::marker::PhantomData<T>,
}
impl<T, D> DeterministicDataset<T, D>
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
    /// Wraps `dataset` with its samples in an order shuffled by an RNG seeded with `seed`.
    ///
    /// The same dataset length and seed always yield the same order.
    pub fn new(dataset: D, seed: u64) -> Self {
        let mut indices: Vec<usize> = (0..dataset.len()).collect();
        let mut rng = StdRng::seed_from_u64(seed);
        Self::fisher_yates_shuffle(&mut indices, &mut rng);
        Self::with_indices(dataset, indices)
    }

    /// Wraps `dataset` keeping its original order (`0, 1, ..., len - 1`).
    pub fn sequential(dataset: D) -> Self {
        let indices = (0..dataset.len()).collect();
        Self::with_indices(dataset, indices)
    }

    /// Wraps `dataset` with its samples in reverse order (`len - 1, ..., 1, 0`).
    pub fn reverse(dataset: D) -> Self {
        let indices = (0..dataset.len()).rev().collect();
        Self::with_indices(dataset, indices)
    }

    /// Returns a reference to the wrapped dataset.
    pub fn inner(&self) -> &D {
        &self.dataset
    }

    /// Returns the current index mapping: position `i` reads sample `indices()[i]`.
    pub fn indices(&self) -> &[usize] {
        &self.indices
    }

    /// Shuffles the *current* index order with an RNG seeded with `seed`.
    ///
    /// The shuffle starts from whatever order is in place, so the result depends on both
    /// `seed` and the order before the call.
    pub fn reshuffle(&mut self, seed: u64) {
        let mut rng = StdRng::seed_from_u64(seed);
        Self::fisher_yates_shuffle(&mut self.indices, &mut rng);
    }

    /// Assembles the wrapper from a dataset and a ready-made index mapping.
    fn with_indices(dataset: D, indices: Vec<usize>) -> Self {
        Self {
            dataset,
            indices,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Shuffles `indices` in place with the Fisher-Yates algorithm.
    ///
    /// Slices of length 0 or 1 are left untouched.
    fn fisher_yates_shuffle<R: Rng>(indices: &mut [usize], rng: &mut R) {
        for i in (1..indices.len()).rev() {
            // The upper bound must be inclusive: allowing `j == i` is what makes every
            // permutation reachable. An exclusive bound (`0..i`) is Sattolo's variant, which
            // only produces single-cycle permutations and never leaves an element in place.
            let j = rng.random_range(0..=i);
            indices.swap(i, j);
        }
    }
}
impl<T, D> Dataset<T> for DeterministicDataset<T, D>
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
    /// Returns the length reported by the wrapped dataset.
    fn len(&self) -> usize {
        self.dataset.len()
    }

    /// Fetches the sample stored at position `index` of the index mapping.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when `index` is outside the index mapping, and
    /// forwards any error produced by the wrapped dataset.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        match self.indices.get(index) {
            Some(&mapped) => self.dataset.get(mapped),
            None => Err(TensorError::invalid_argument(format!(
                "Index {} out of bounds for dataset of length {}",
                index,
                self.indices.len()
            ))),
        }
    }
}
/// Serializable description of an experiment: its name, seed, dataset setup and the
/// environment it was configured in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentConfig {
/// Human-readable experiment name.
pub name: String,
/// Seed recorded for the experiment.
pub seed: u64,
/// Ordering, sampling and transform settings for the dataset.
pub dataset_config: DatasetConfig,
/// Environment snapshot stored alongside the configuration.
pub environment: EnvironmentInfo,
/// Free-form string key/value annotations.
pub metadata: HashMap<String, String>,
}
/// Serializable dataset settings recorded as part of an `ExperimentConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetConfig {
/// How sample indices are ordered.
pub ordering: OrderingStrategy,
/// Sampling settings.
pub sampling: SamplingConfig,
/// Transform settings, stored as a list.
pub transforms: Vec<TransformConfig>,
}
/// Ways of ordering sample indices, as interpreted by
/// `DeterministicOrdering::create_indices`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OrderingStrategy {
/// Indices in ascending order: `0, 1, ..., len - 1`.
Sequential,
/// Indices in descending order: `len - 1, ..., 1, 0`.
Reverse,
/// Ascending indices shuffled by an RNG seeded with `seed`.
Shuffled { seed: u64 },
/// Caller-supplied indices; entries larger than the last valid index are capped to it.
Custom { indices: Vec<usize> },
}
/// Serializable description of a sampling strategy.
///
/// NOTE(review): nothing in this file interprets these fields; they appear to be recorded for
/// bookkeeping only. Confirm against the code that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
/// Name of the sampling strategy.
pub strategy: String,
/// Seed recorded for the sampling step.
pub seed: u64,
/// Named numeric parameters of the strategy.
pub parameters: HashMap<String, f64>,
}
/// Serializable description of a single transform.
///
/// NOTE(review): nothing in this file interprets these fields; they appear to be recorded for
/// bookkeeping only. Confirm against the code that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformConfig {
/// Name of the transform.
pub name: String,
/// Seed recorded for the transform.
pub seed: u64,
/// Named parameters of the transform as arbitrary JSON values.
pub parameters: HashMap<String, serde_json::Value>,
}
/// Namespace for stateless helpers that build reproducible index orderings.
pub struct DeterministicOrdering;
impl DeterministicOrdering {
    /// Builds the index order for a dataset of `len` samples according to `strategy`.
    ///
    /// * `Sequential` yields `0..len`, `Reverse` the same range backwards.
    /// * `Shuffled` yields `0..len` shuffled by an RNG seeded with the given seed.
    /// * `Custom` yields the supplied indices with every entry capped at `len - 1`; for an
    ///   empty dataset (`len == 0`) no index is valid, so the result is empty.
    pub fn create_indices(len: usize, strategy: &OrderingStrategy) -> Vec<usize> {
        match strategy {
            OrderingStrategy::Sequential => (0..len).collect(),
            OrderingStrategy::Reverse => (0..len).rev().collect(),
            OrderingStrategy::Shuffled { seed } => {
                let mut indices: Vec<usize> = (0..len).collect();
                let mut rng = StdRng::seed_from_u64(*seed);
                Self::shuffle_indices(&mut indices, &mut rng);
                indices
            }
            OrderingStrategy::Custom { indices } => match len.checked_sub(1) {
                Some(last) => indices.iter().map(|&i| i.min(last)).collect(),
                // Capping to index 0 would point past the end of an empty dataset.
                None => Vec::new(),
            },
        }
    }

    /// Shuffles `indices` in place with the Fisher-Yates algorithm, drawing from `rng`.
    ///
    /// Slices of length 0 or 1 are left untouched.
    pub fn shuffle_indices<R: Rng>(indices: &mut [usize], rng: &mut R) {
        for i in (1..indices.len()).rev() {
            // The upper bound must be inclusive: allowing `j == i` is what makes every
            // permutation reachable. An exclusive bound (`0..i`) is Sattolo's variant, which
            // only produces single-cycle permutations and never leaves an element in place.
            let j = rng.random_range(0..=i);
            indices.swap(i, j);
        }
    }

    /// Groups the samples of `dataset` by class, shuffles each group with one RNG seeded with
    /// `seed`, and returns the groups concatenated in ascending class order.
    ///
    /// The class of a sample is its label converted to `usize`: the scalar value for scalar
    /// labels, otherwise the first element of the label slice. Labels that cannot be read
    /// count as class 0. Because float-to-integer `as` casts saturate, negative or NaN labels
    /// also land in class 0. Samples whose class is `>= num_classes` are left out.
    ///
    /// # Errors
    ///
    /// Forwards the first error returned by `dataset.get`.
    pub fn create_stratified_indices_f32(
        dataset: &dyn Dataset<f32>,
        seed: u64,
        num_classes: usize,
    ) -> Result<Vec<usize>> {
        let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); num_classes];
        for i in 0..dataset.len() {
            let (_, labels) = dataset.get(i)?;
            let raw_label = if labels.is_scalar() {
                labels.get(&[]).unwrap_or(0.0)
            } else if let Some(slice) = labels.as_slice() {
                slice.first().copied().unwrap_or(0.0)
            } else {
                0.0
            };
            let class = raw_label as usize;
            if let Some(bucket) = class_indices.get_mut(class) {
                bucket.push(i);
            }
        }

        let mut rng = StdRng::seed_from_u64(seed);
        let kept: usize = class_indices.iter().map(Vec::len).sum();
        let mut result = Vec::with_capacity(kept);
        for class_samples in &mut class_indices {
            Self::shuffle_indices(class_samples, &mut rng);
            result.extend_from_slice(class_samples);
        }
        Ok(result)
    }
}
/// Extension methods that wrap any dataset in a [`DeterministicDataset`].
pub trait ReproducibilityExt<T>: Dataset<T> + Sized
where
    T: Clone + Default + Send + Sync + 'static,
{
    /// Consumes the dataset and returns it with its samples shuffled by an RNG seeded with
    /// `seed`.
    fn deterministic(self, seed: u64) -> DeterministicDataset<T, Self> {
        DeterministicDataset::<T, Self>::new(self, seed)
    }

    /// Consumes the dataset and returns it wrapped in its original order.
    fn sequential(self) -> DeterministicDataset<T, Self> {
        DeterministicDataset::<T, Self>::sequential(self)
    }

    /// Consumes the dataset and returns it wrapped in reverse order.
    fn reverse(self) -> DeterministicDataset<T, Self> {
        DeterministicDataset::<T, Self>::reverse(self)
    }
}
/// Every dataset whose element type meets the bounds gets the extension methods for free.
impl<T, D> ReproducibilityExt<T> for D
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
}
/// Collects operation records for one experiment and can persist them as JSON.
#[derive(Debug)]
pub struct ExperimentTracker {
/// Configuration the tracker was created (or loaded) with.
config: ExperimentConfig,
/// Instant at which this tracker value was created, either by `new` or by `load_from_file`.
start_time: std::time::Instant,
/// Operations recorded so far, in insertion order.
operations: Vec<OperationRecord>,
}
/// One recorded operation of an experiment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationRecord {
/// Name of the operation.
pub name: String,
/// Time the record was created, in whole seconds since the Unix epoch.
pub timestamp: u64,
/// Duration of the operation in milliseconds.
pub duration_ms: u64,
/// Seed associated with the operation.
pub seed: u64,
/// Free-form string key/value annotations.
pub metadata: HashMap<String, String>,
}
impl ExperimentTracker {
    /// Starts tracking an experiment described by `config`; the elapsed-time clock starts now.
    pub fn new(config: ExperimentConfig) -> Self {
        let start_time = std::time::Instant::now();
        Self {
            config,
            start_time,
            operations: Vec::new(),
        }
    }

    /// Appends a record for an operation that took `duration` and used `seed`.
    ///
    /// The record is stamped with the current wall-clock time in seconds since the Unix epoch
    /// (zero if the clock reports a time before the epoch).
    pub fn record_operation(
        &mut self,
        name: String,
        duration: std::time::Duration,
        seed: u64,
        metadata: HashMap<String, String>,
    ) {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        self.operations.push(OperationRecord {
            name,
            timestamp: since_epoch.as_secs(),
            duration_ms: duration.as_millis() as u64,
            seed,
            metadata,
        });
    }

    /// Returns the experiment configuration.
    pub fn config(&self) -> &ExperimentConfig {
        &self.config
    }

    /// Returns the operations recorded so far, oldest first.
    pub fn operations(&self) -> &[OperationRecord] {
        &self.operations
    }

    /// Writes the configuration, the recorded operations and the time elapsed since this
    /// tracker was created to `path` as pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error if serialization or writing the file fails.
    pub fn save_to_file<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
        let snapshot = ExperimentData {
            config: self.config.clone(),
            operations: self.operations.clone(),
            total_duration_ms: self.start_time.elapsed().as_millis() as u64,
        };
        let json_text = serde_json::to_string_pretty(&snapshot).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to serialize experiment data: {e}"))
        })?;
        std::fs::write(path, json_text).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to write experiment file: {e}"))
        })
    }

    /// Rebuilds a tracker from a file written by `save_to_file`.
    ///
    /// Only the configuration and the operation records are restored. The stored total
    /// duration is discarded and the elapsed-time clock restarts at the moment of loading.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error if the file cannot be read or its JSON cannot be
    /// parsed.
    pub fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
        let json_text = std::fs::read_to_string(path).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to read experiment file: {e}"))
        })?;
        let parsed: ExperimentData = serde_json::from_str(&json_text).map_err(|e| {
            TensorError::invalid_argument(format!("Failed to parse experiment JSON: {e}"))
        })?;
        let ExperimentData {
            config, operations, ..
        } = parsed;
        Ok(Self {
            config,
            start_time: std::time::Instant::now(),
            operations,
        })
    }
}
/// On-disk JSON layout used by `ExperimentTracker::save_to_file` and `load_from_file`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ExperimentData {
/// Experiment configuration.
config: ExperimentConfig,
/// Recorded operations.
operations: Vec<OperationRecord>,
/// Milliseconds elapsed on the tracker's clock at save time; not used when loading.
total_duration_ms: u64,
}
/// Namespace for convenience functions that operate on the process-wide `SeedManager`.
pub struct DeterministicOps;
impl DeterministicOps {
pub fn set_global_seed(seed: u64) {
SeedManager::set_global(SeedManager::new(seed));
}
pub fn get_rng(component: &str) -> StdRng {
let manager = SeedManager::global();
let mut manager = manager.lock().unwrap_or_else(|e| e.into_inner());
manager.create_rng(component)
}
pub fn next_operation_seed() -> u64 {
let manager = SeedManager::global();
let mut manager = manager.lock().unwrap_or_else(|e| e.into_inner());
manager.next_operation_seed()
}
pub fn capture_environment() -> EnvironmentInfo {
let manager = SeedManager::global();
let manager = manager.lock().unwrap_or_else(|e| e.into_inner());
EnvironmentInfo::capture(&manager)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tempfile::TempDir;

    /// Component seeds are cached per name and operation seeds change between calls.
    #[test]
    fn test_seed_manager() {
        let mut manager = SeedManager::new(42);
        assert_eq!(manager.master_seed(), 42);
        let seed1 = manager.get_component_seed("test");
        let seed2 = manager.get_component_seed("test");
        assert_eq!(seed1, seed2);
        let seed3 = manager.get_component_seed("other");
        assert_ne!(seed1, seed3);
        let op1 = manager.next_operation_seed();
        let op2 = manager.next_operation_seed();
        assert_ne!(op1, op2);
    }

    /// The shuffled order is a permutation, repeats for equal seeds and differs across seeds.
    #[test]
    fn test_deterministic_dataset() {
        // 32 samples: with so many permutations, two seeds agreeing by chance is practically
        // impossible, so the inequality check does not hinge on one particular RNG stream.
        let features = Tensor::from_vec(vec![1.0; 64], &[32, 2])
            .expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(vec![0.0; 32], &[32]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);
        let det_dataset = DeterministicDataset::new(dataset, 42);
        assert_eq!(det_dataset.len(), 32);

        let mut sorted = det_dataset.indices().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..32).collect::<Vec<usize>>());

        let det_dataset2 = DeterministicDataset::new(det_dataset.inner().clone(), 42);
        assert_eq!(det_dataset.indices(), det_dataset2.indices());
        let det_dataset3 = DeterministicDataset::new(det_dataset.inner().clone(), 123);
        assert_ne!(det_dataset.indices(), det_dataset3.indices());
    }

    /// Every ordering strategy produces the documented index sequence.
    #[test]
    fn test_ordering_strategies() {
        let len = 5;
        let seq_indices = DeterministicOrdering::create_indices(len, &OrderingStrategy::Sequential);
        assert_eq!(seq_indices, vec![0, 1, 2, 3, 4]);
        let rev_indices = DeterministicOrdering::create_indices(len, &OrderingStrategy::Reverse);
        assert_eq!(rev_indices, vec![4, 3, 2, 1, 0]);

        // Shuffles are compared on a longer sequence so that two seeds agreeing by chance is
        // practically impossible.
        let shuffle_len = 64;
        let shuffled1 = DeterministicOrdering::create_indices(
            shuffle_len,
            &OrderingStrategy::Shuffled { seed: 42 },
        );
        let shuffled2 = DeterministicOrdering::create_indices(
            shuffle_len,
            &OrderingStrategy::Shuffled { seed: 42 },
        );
        assert_eq!(shuffled1, shuffled2);
        let mut sorted = shuffled1.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..shuffle_len).collect::<Vec<usize>>());
        let shuffled3 = DeterministicOrdering::create_indices(
            shuffle_len,
            &OrderingStrategy::Shuffled { seed: 123 },
        );
        assert_ne!(shuffled1, shuffled3);

        let custom_indices = DeterministicOrdering::create_indices(
            len,
            &OrderingStrategy::Custom {
                indices: vec![2, 0, 4, 1, 3],
            },
        );
        assert_eq!(custom_indices, vec![2, 0, 4, 1, 3]);

        // Out-of-range custom entries are capped at the last valid index.
        let clamped = DeterministicOrdering::create_indices(
            len,
            &OrderingStrategy::Custom {
                indices: vec![1, 7],
            },
        );
        assert_eq!(clamped, vec![1, 4]);
    }

    /// A captured environment has all basic fields filled in and carries the master seed.
    #[test]
    fn test_environment_capture() {
        let manager = SeedManager::new(42);
        let env = EnvironmentInfo::capture(&manager);
        assert!(!env.rust_version.is_empty());
        assert!(!env.os.is_empty());
        assert!(!env.arch.is_empty());
        assert!(env.num_cpus > 0);
        assert_eq!(env.seed_info.master_seed, 42);
    }

    /// Recorded operations survive a save/load round trip through a JSON file.
    #[test]
    fn test_experiment_tracker() {
        let config = ExperimentConfig {
            name: "test_experiment".to_string(),
            seed: 42,
            dataset_config: DatasetConfig {
                ordering: OrderingStrategy::Shuffled { seed: 42 },
                sampling: SamplingConfig {
                    strategy: "random".to_string(),
                    seed: 42,
                    parameters: HashMap::new(),
                },
                transforms: Vec::new(),
            },
            environment: EnvironmentInfo::capture(&SeedManager::new(42)),
            metadata: HashMap::new(),
        };
        let mut tracker = ExperimentTracker::new(config);
        tracker.record_operation(
            "data_loading".to_string(),
            std::time::Duration::from_millis(100),
            42,
            HashMap::new(),
        );
        assert_eq!(tracker.operations().len(), 1);
        assert_eq!(tracker.operations()[0].name, "data_loading");
        assert_eq!(tracker.operations()[0].duration_ms, 100);
        let temp_dir = TempDir::new().expect("test: temp dir creation should succeed");
        let file_path = temp_dir.path().join("experiment.json");
        tracker
            .save_to_file(&file_path)
            .expect("test: save to file should succeed");
        let loaded_tracker = ExperimentTracker::load_from_file(&file_path)
            .expect("test: load from file should succeed");
        assert_eq!(loaded_tracker.config().name, "test_experiment");
        assert_eq!(loaded_tracker.operations().len(), 1);
    }

    /// The extension-trait methods wrap a dataset with the expected index order.
    #[test]
    fn test_reproducibility_ext() {
        let features_data = vec![1.0, 2.0, 3.0, 4.0];
        let labels_data = vec![0.0, 1.0];
        let features =
            Tensor::from_vec(features_data, &[2, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[2]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);
        let det_dataset = dataset.deterministic(42);
        assert_eq!(det_dataset.len(), 2);
        let seq_dataset = det_dataset.inner().clone().sequential();
        assert_eq!(seq_dataset.indices(), &[0, 1]);
        let rev_dataset = det_dataset.inner().clone().reverse();
        assert_eq!(rev_dataset.indices(), &[1, 0]);
    }

    /// The global helpers hand out repeatable RNGs per component and distinct operation seeds.
    #[test]
    fn test_deterministic_ops() {
        DeterministicOps::set_global_seed(12345);
        let mut rng1 = DeterministicOps::get_rng("test_component");
        let val1: f64 = rng1.random();
        let mut rng2 = DeterministicOps::get_rng("test_component");
        let val2: f64 = rng2.random();
        assert_eq!(val1, val2);
        let mut rng3 = DeterministicOps::get_rng("other_component");
        let val3: f64 = rng3.random();
        assert_ne!(val1, val3);
        let op1 = DeterministicOps::next_operation_seed();
        let op2 = DeterministicOps::next_operation_seed();
        assert_ne!(op1, op2);
        let env = DeterministicOps::capture_environment();
        assert_eq!(env.seed_info.master_seed, 12345);
    }
}