use crate::{Result, TensorError};
use std::sync::{Arc, Mutex, OnceLock};
/// Bookkeeping used to derive reproducible per-operation seeds.
///
/// All fields are public, so callers may inspect or adjust them directly.
#[derive(Debug, Clone)]
pub struct DeterministicState {
    /// Whether deterministic mode is switched on.
    pub enabled: bool,
    /// Base seed from which every subseed is derived.
    pub global_seed: u64,
    /// Incremented by one on every [`DeterministicState::next_subseed`] call.
    pub operation_counter: u64,
    /// Strict-mode flag; this type only stores it.
    pub strict_mode: bool,
    /// Preference for deterministic algorithms; this type only stores it.
    pub prefer_deterministic_algorithms: bool,
    /// One `"<operation>: seed=<subseed>"` entry per recorded subseed request.
    pub operation_log: Vec<String>,
    /// Upper bound on `operation_log.len()`; once reached, new entries are dropped.
    pub max_log_size: usize,
}

impl Default for DeterministicState {
    /// Disabled state: seed 0, counter 0, non-strict, deterministic algorithms
    /// preferred, empty log capped at 1000 entries.
    fn default() -> Self {
        Self {
            enabled: false,
            global_seed: 0,
            operation_counter: 0,
            strict_mode: false,
            prefer_deterministic_algorithms: true,
            operation_log: Vec::new(),
            max_log_size: 1000,
        }
    }
}

impl DeterministicState {
    /// Creates an enabled state with the given seed; every other field takes
    /// its [`Default`] value.
    pub fn new(seed: u64) -> Self {
        Self {
            enabled: true,
            global_seed: seed,
            ..Default::default()
        }
    }

    /// Derives the next subseed for `operation_name` and advances the counter.
    ///
    /// The subseed is `global_seed * 6364136223846793005 + operation_counter +
    /// hash_string(operation_name)`, all in wrapping `u64` arithmetic. The
    /// request is appended to `operation_log` unless the log has already
    /// reached `max_log_size`.
    pub fn next_subseed(&mut self, operation_name: &str) -> u64 {
        let subseed = self
            .global_seed
            .wrapping_mul(6364136223846793005)
            .wrapping_add(self.operation_counter)
            .wrapping_add(hash_string(operation_name));

        // Wrap rather than `+= 1`: the counter is a public field (and can be
        // set through `restore`), so it may legitimately sit at `u64::MAX`,
        // where a plain increment panics in builds with overflow checks.
        self.operation_counter = self.operation_counter.wrapping_add(1);

        if self.operation_log.len() < self.max_log_size {
            self.operation_log
                .push(format!("{}: seed={}", operation_name, subseed));
        }
        subseed
    }

    /// Sets the operation counter back to zero.
    pub fn reset_counter(&mut self) {
        self.operation_counter = 0;
    }

    /// Removes every entry from the operation log.
    pub fn clear_log(&mut self) {
        self.operation_log.clear();
    }

    /// Captures the seed, the counter and the enabled flag.
    ///
    /// Strict mode, the algorithm preference and the log are not part of the
    /// snapshot.
    pub fn snapshot(&self) -> DeterministicSnapshot {
        DeterministicSnapshot {
            global_seed: self.global_seed,
            operation_counter: self.operation_counter,
            enabled: self.enabled,
        }
    }

    /// Overwrites the seed, the counter and the enabled flag from `snapshot`,
    /// leaving all other fields untouched.
    pub fn restore(&mut self, snapshot: &DeterministicSnapshot) {
        self.global_seed = snapshot.global_seed;
        self.operation_counter = snapshot.operation_counter;
        self.enabled = snapshot.enabled;
    }
}

/// Copyable subset of [`DeterministicState`] produced by
/// [`DeterministicState::snapshot`] and consumed by
/// [`DeterministicState::restore`].
#[derive(Debug, Clone, Copy)]
pub struct DeterministicSnapshot {
    /// Saved base seed.
    pub global_seed: u64,
    /// Saved operation counter.
    pub operation_counter: u64,
    /// Saved deterministic-mode flag.
    pub enabled: bool,
}

/// 64-bit FNV-1a hash of the UTF-8 bytes of `s`.
///
/// Starts from the offset basis, then for each byte XORs it in and multiplies
/// by the FNV prime with wrapping arithmetic. The empty string hashes to the
/// offset basis.
fn hash_string(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    s.bytes().fold(FNV_OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Process-wide deterministic state, created on first access.
static GLOBAL_STATE: OnceLock<Arc<Mutex<DeterministicState>>> = OnceLock::new();

/// Returns the shared state handle, initialising it with
/// `DeterministicState::default()` the first time it is requested.
fn get_global_state() -> &'static Arc<Mutex<DeterministicState>> {
    GLOBAL_STATE.get_or_init(|| {
        let initial = DeterministicState::default();
        Arc::new(Mutex::new(initial))
    })
}
/// Turns deterministic mode on or off in the global state.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn set_deterministic_mode(enabled: bool) {
    let mut guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.enabled = enabled;
}
/// Reports whether deterministic mode is currently enabled in the global state.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn is_deterministic_mode() -> bool {
    let guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.enabled
}
/// Installs a new global seed.
///
/// Besides storing `seed`, this zeroes the operation counter and empties the
/// operation log, all under a single lock acquisition.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn set_global_seed(seed: u64) {
    let mut guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.clear_log();
    guard.operation_counter = 0;
    guard.global_seed = seed;
}
/// Returns the seed currently stored in the global state.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn get_global_seed() -> u64 {
    let guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.global_seed
}
/// Sets the strict-mode flag in the global state.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn set_strict_mode(strict: bool) {
    let mut guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.strict_mode = strict;
}
/// Reports whether the strict-mode flag is set in the global state.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn is_strict_mode() -> bool {
    let guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.strict_mode
}
/// Returns a seed for the named operation.
///
/// When deterministic mode is enabled the seed comes from the global state's
/// `next_subseed(operation_name)`, which also advances the operation counter.
/// Otherwise the seed is the current wall-clock offset from the Unix epoch in
/// nanoseconds, truncated to 64 bits.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn get_operation_seed(operation_name: &str) -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    {
        let mut state = get_global_state()
            .lock()
            .expect("lock should not be poisoned");
        if state.enabled {
            return state.next_subseed(operation_name);
        }
        // The guard is dropped here so the clock is read without holding the
        // global lock.
    }

    // A clock set before the epoch makes `duration_since` fail; the error
    // still carries the distance to the epoch, which is just as usable as a
    // non-deterministic seed and avoids panicking.
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_else(|err| err.duration())
        // Truncation to the low 64 bits is intentional: only entropy matters.
        .as_nanos() as u64
}
/// Sets the global operation counter back to zero; the seed and the log are
/// left as they are.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn reset_operation_counter() {
    let mut guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.reset_counter();
}
/// Takes a snapshot of the global state by calling its `snapshot()` method.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn get_state_snapshot() -> DeterministicSnapshot {
    let guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.snapshot()
}
/// Applies `snapshot` to the global state by calling its `restore()` method.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn restore_state_snapshot(snapshot: &DeterministicSnapshot) {
    let mut guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.restore(snapshot);
}
/// Returns an owned copy of the global operation log.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn get_operation_log() -> Vec<String> {
    let guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.operation_log.clone()
}
/// Empties the global operation log.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn clear_operation_log() {
    let mut guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    guard.clear_log();
}
/// Sets the global log capacity to 1000 entries so that subseed requests are
/// recorded again after logging was turned off.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn enable_operation_logging() {
    get_global_state()
        .lock()
        .expect("lock should not be poisoned")
        .max_log_size = 1000;
}
/// Replaces the whole global state with `DeterministicState::default()`.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
#[doc(hidden)]
pub fn reset_to_defaults() {
    let fresh = DeterministicState::default();
    *get_global_state()
        .lock()
        .expect("lock should not be poisoned") = fresh;
}
/// RAII guard that changes the global deterministic settings and puts the
/// previous snapshot back when dropped.
///
/// The guard must be bound to a named variable (for example `let _scope = ...`);
/// an unbound guard is dropped immediately, which restores the previous state
/// before any work runs.
#[must_use = "dropping the scope immediately restores the previous state; bind it to a variable"]
pub struct DeterministicScope {
    /// Snapshot taken before the scope modified the global state.
    previous_state: DeterministicSnapshot,
}

impl DeterministicScope {
    /// Enables deterministic mode with `seed` for the lifetime of the guard.
    ///
    /// The snapshot is taken first, then deterministic mode is enabled and the
    /// seed installed through `set_global_seed`.
    pub fn new(seed: u64) -> Self {
        let previous_state = get_state_snapshot();
        set_deterministic_mode(true);
        set_global_seed(seed);
        Self { previous_state }
    }

    /// Sets deterministic mode to `enabled` for the lifetime of the guard
    /// without touching the seed.
    pub fn with_mode(enabled: bool) -> Self {
        let previous_state = get_state_snapshot();
        set_deterministic_mode(enabled);
        Self { previous_state }
    }
}

impl Drop for DeterministicScope {
    /// Restores the snapshot captured when the guard was created.
    fn drop(&mut self) {
        restore_state_snapshot(&self.previous_state);
    }
}
/// Bundle of settings that can be pushed into the global state in one call via
/// `apply`.
#[derive(Debug, Clone)]
pub struct DeterministicConfig {
    /// Global seed to install.
    pub seed: u64,
    /// Value for the strict-mode flag.
    pub strict: bool,
    /// Value for the "prefer deterministic algorithms" flag.
    pub prefer_deterministic: bool,
    /// Whether subseed requests should be recorded in the operation log.
    pub log_operations: bool,
}

impl Default for DeterministicConfig {
    /// Seed 42, non-strict, deterministic algorithms preferred, logging off.
    fn default() -> Self {
        DeterministicConfig {
            log_operations: false,
            prefer_deterministic: true,
            strict: false,
            seed: 42,
        }
    }
}
impl DeterministicConfig {
    /// Pushes this configuration into the global state.
    ///
    /// Installs the seed, enables deterministic mode and sets strict mode
    /// through the public setters, then updates the algorithm preference and
    /// the log capacity directly: 1000 entries when `log_operations` is set,
    /// otherwise the log is emptied and its capacity set to zero.
    ///
    /// # Panics
    ///
    /// Panics if the global state mutex is poisoned.
    pub fn apply(&self) {
        set_global_seed(self.seed);
        set_deterministic_mode(true);
        set_strict_mode(self.strict);

        let mut guard = get_global_state()
            .lock()
            .expect("lock should not be poisoned");
        guard.prefer_deterministic_algorithms = self.prefer_deterministic;
        if self.log_operations {
            guard.max_log_size = 1000;
        } else {
            guard.clear_log();
            guard.max_log_size = 0;
        }
    }
}
/// Runs `operation` twice from the same seed state and reports whether both
/// runs produced equal results.
///
/// Before each run the current global seed is re-installed through
/// `set_global_seed` and the operation counter is reset. After the second run
/// the snapshot taken on entry is restored. The function always returns `Ok`;
/// the boolean carries the comparison result.
///
/// `operation_name` is accepted but not currently used by this function.
pub fn verify_reproducibility<F, T>(operation_name: &str, mut operation: F) -> Result<bool>
where
    F: FnMut() -> T,
    T: PartialEq,
{
    let saved = get_state_snapshot();
    let base_seed = saved.global_seed;

    // Rewinds the seed state, then executes the operation once.
    let mut run_from_scratch = || {
        set_global_seed(base_seed);
        reset_operation_counter();
        operation()
    };

    let first = run_from_scratch();
    let second = run_from_scratch();

    restore_state_snapshot(&saved);
    Ok(first == second)
}
/// Records that `operation_name` is not deterministic.
///
/// * Deterministic mode off: returns `Ok(())` silently.
/// * Deterministic mode on, strict mode off: prints a warning to stderr and
///   returns `Ok(())`.
/// * Deterministic mode on, strict mode on: returns an error.
///
/// # Errors
///
/// Returns the error built by `TensorError::invalid_operation_simple` when
/// both deterministic mode and strict mode are enabled.
pub fn mark_non_deterministic(operation_name: &str) -> Result<()> {
    // Read the flag once so the strict check and the warning agree on one
    // value even if another thread toggles the mode in between.
    let deterministic = is_deterministic_mode();
    if !deterministic {
        return Ok(());
    }

    if is_strict_mode() {
        return Err(TensorError::invalid_operation_simple(format!(
            "Operation '{}' is non-deterministic but strict deterministic mode is enabled",
            operation_name
        )));
    }

    eprintln!(
        "Warning: Operation '{}' may not be fully deterministic",
        operation_name
    );
    Ok(())
}
/// Returns `true` only when deterministic mode is enabled and the
/// "prefer deterministic algorithms" flag is set; both are read under one lock.
///
/// # Panics
///
/// Panics if the global state mutex is poisoned.
pub fn should_use_deterministic_gpu_ops() -> bool {
    let guard = get_global_state()
        .lock()
        .expect("lock should not be poisoned");
    if !guard.enabled {
        return false;
    }
    guard.prefer_deterministic_algorithms
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises the tests in this module, since they all mutate the same
    /// process-wide state.
    static TEST_MUTEX: Mutex<()> = Mutex::new(());

    /// Takes the test lock and resets the global state to its defaults.
    ///
    /// A poisoned test lock is recovered instead of unwrapped, so that one
    /// failing test does not make every later test fail as well.
    fn lock_and_reset() -> MutexGuard<'static, ()> {
        let guard = TEST_MUTEX
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        reset_to_defaults();
        guard
    }

    #[test]
    fn test_deterministic_mode_toggle() {
        let _guard = lock_and_reset();
        set_deterministic_mode(true);
        assert!(is_deterministic_mode());
        set_deterministic_mode(false);
        assert!(!is_deterministic_mode());
    }

    #[test]
    fn test_global_seed() {
        let _guard = lock_and_reset();
        set_global_seed(12345);
        assert_eq!(get_global_seed(), 12345);
        set_global_seed(67890);
        assert_eq!(get_global_seed(), 67890);
    }

    #[test]
    fn test_operation_seed_generation() {
        let _guard = lock_and_reset();
        set_deterministic_mode(true);
        set_global_seed(42);
        let seed1 = get_operation_seed("test_op");
        let seed2 = get_operation_seed("test_op");
        // The counter advanced between the calls, so the seeds differ.
        assert_ne!(seed1, seed2);
        reset_operation_counter();
        let seed3 = get_operation_seed("test_op");
        assert_eq!(seed1, seed3);
    }

    #[test]
    fn test_operation_seed_uniqueness() {
        let _guard = lock_and_reset();
        set_deterministic_mode(true);
        set_global_seed(42);
        reset_operation_counter();
        let seed_a = get_operation_seed("operation_a");
        let seed_b = get_operation_seed("operation_b");
        assert_ne!(seed_a, seed_b);
    }

    #[test]
    fn test_snapshot_and_restore() {
        let _guard = lock_and_reset();
        set_deterministic_mode(true);
        set_global_seed(100);
        let _ = get_operation_seed("op1");
        let _ = get_operation_seed("op2");
        let snapshot = get_state_snapshot();
        let _ = get_operation_seed("op3");
        restore_state_snapshot(&snapshot);
        let seed_after_restore = get_operation_seed("op3");
        restore_state_snapshot(&snapshot);
        let seed_repeat = get_operation_seed("op3");
        assert_eq!(seed_after_restore, seed_repeat);
    }

    #[test]
    fn test_deterministic_scope() {
        let _guard = lock_and_reset();
        set_deterministic_mode(false);
        set_global_seed(100);
        {
            let _scope = DeterministicScope::new(200);
            assert!(is_deterministic_mode());
            assert_eq!(get_global_seed(), 200);
        }
        // Dropping the scope restores the previous mode and seed.
        assert!(!is_deterministic_mode());
        assert_eq!(get_global_seed(), 100);
    }

    #[test]
    fn test_strict_mode() {
        let _guard = lock_and_reset();
        set_strict_mode(true);
        assert!(is_strict_mode());
        set_strict_mode(false);
        assert!(!is_strict_mode());
    }

    #[test]
    fn test_mark_non_deterministic() {
        let _guard = lock_and_reset();
        set_deterministic_mode(true);
        set_strict_mode(false);
        assert!(mark_non_deterministic("test_op").is_ok());
        set_strict_mode(true);
        assert!(mark_non_deterministic("test_op").is_err());
    }

    #[test]
    fn test_config_apply() {
        let _guard = lock_and_reset();
        let config = DeterministicConfig {
            seed: 777,
            strict: true,
            prefer_deterministic: true,
            log_operations: false,
        };
        config.apply();
        assert_eq!(get_global_seed(), 777);
        assert!(is_deterministic_mode());
        assert!(is_strict_mode());
    }

    #[test]
    fn test_operation_log() {
        let _guard = lock_and_reset();
        enable_operation_logging();
        set_deterministic_mode(true);
        set_global_seed(42);
        let _ = get_operation_seed("op1");
        let _ = get_operation_seed("op2");
        let log = get_operation_log();
        assert_eq!(log.len(), 2);
        assert!(log[0].contains("op1"));
        assert!(log[1].contains("op2"));
    }

    #[test]
    fn test_hash_string_deterministic() {
        // Pure function: no global state involved, so no lock is needed.
        let hash1 = hash_string("test");
        let hash2 = hash_string("test");
        assert_eq!(hash1, hash2);
        let hash3 = hash_string("different");
        assert_ne!(hash1, hash3);
    }

    #[test]
    fn test_reproducibility_with_counter_reset() {
        let _guard = lock_and_reset();
        set_deterministic_mode(true);
        set_global_seed(42);
        reset_operation_counter();
        let seeds1: Vec<u64> = (0..5)
            .map(|i| get_operation_seed(&format!("op{}", i)))
            .collect();
        reset_operation_counter();
        let seeds2: Vec<u64> = (0..5)
            .map(|i| get_operation_seed(&format!("op{}", i)))
            .collect();
        assert_eq!(seeds1, seeds2);
    }

    #[test]
    fn test_non_deterministic_mode_uses_system_time() {
        let _guard = lock_and_reset();
        set_deterministic_mode(false);
        // Only checks that the time-based path runs without panicking; the
        // values are not compared because clock resolution is platform-specific.
        let _seed1 = get_operation_seed("test");
        std::thread::sleep(std::time::Duration::from_nanos(100));
        let _seed2 = get_operation_seed("test");
    }
}