use std::sync::atomic::{AtomicBool, Ordering};
pub use gllm_kernels::ops::flash_attention::DeterministicConfig;
/// Process-wide flag recording whether deterministic mode is currently requested.
///
/// Starts out `false`. It is written by `DeterministicConfigExt::apply_global`,
/// set and later restored by `DeterministicGuard`, and read by
/// `is_deterministic_mode` / `DeterministicGuard::is_active`.
static DETERMINISTIC_MODE: AtomicBool = AtomicBool::new(false);
/// Extension methods that let a [`DeterministicConfig`] drive the
/// process-wide deterministic-mode flag.
pub trait DeterministicConfigExt {
    /// Publishes this configuration to the global flag reported by
    /// [`is_deterministic_mode`], overwriting whatever value it held.
    fn apply_global(&self);
}

impl DeterministicConfigExt for DeterministicConfig {
    fn apply_global(&self) {
        // Only the GPU-nondeterminism switch feeds the global flag; the other
        // settings of the config are not consulted here.
        let enabled = self.no_gpu_nondeterminism;
        DETERMINISTIC_MODE.store(enabled, Ordering::SeqCst);
    }
}
/// RAII guard that applies a [`DeterministicConfig`] to the global
/// deterministic-mode flag and restores the previously observed value on drop.
///
/// Guards are expected to be dropped in reverse creation order (ordinary
/// scoping). Each guard restores exactly the value it replaced, so dropping
/// guards out of order restores those captured values in drop order.
#[must_use = "the previous deterministic mode is restored as soon as the guard is dropped"]
pub struct DeterministicGuard {
    /// Global flag value observed at the instant this guard applied its config.
    previous_mode: bool,
    /// Configuration this guard applied; exposed through [`DeterministicGuard::config`].
    config: DeterministicConfig,
}

impl DeterministicGuard {
    /// Applies `config` to the global flag and returns a guard that undoes it on drop.
    pub fn new(config: DeterministicConfig) -> Self {
        // Swap atomically so the value restored on drop is exactly the one this
        // guard replaced. A separate load + store lets two concurrent
        // constructors capture the same "previous" value. One of them would then
        // reset the flag while the other guard is still alive.
        // The stored value mirrors `DeterministicConfigExt::apply_global`.
        let previous_mode =
            DETERMINISTIC_MODE.swap(config.no_gpu_nondeterminism, Ordering::SeqCst);
        Self {
            previous_mode,
            config,
        }
    }

    /// Shorthand for `DeterministicGuard::new(DeterministicConfig::strict())`.
    pub fn strict() -> Self {
        Self::new(DeterministicConfig::strict())
    }

    /// Returns the current global deterministic-mode flag.
    ///
    /// This reflects process-wide state. Another guard or an `apply_global`
    /// call may have changed it since this guard was created.
    pub fn is_active(&self) -> bool {
        DETERMINISTIC_MODE.load(Ordering::SeqCst)
    }

    /// Returns the configuration this guard applied.
    pub fn config(&self) -> &DeterministicConfig {
        &self.config
    }
}

impl Drop for DeterministicGuard {
    fn drop(&mut self) {
        DETERMINISTIC_MODE.store(self.previous_mode, Ordering::SeqCst);
    }
}
/// Reports whether deterministic mode is currently enabled process-wide.
///
/// The flag is written by `DeterministicConfigExt::apply_global` and by
/// `DeterministicGuard` construction and drop.
pub fn is_deterministic_mode() -> bool {
    let ordering = Ordering::SeqCst;
    DETERMINISTIC_MODE.load(ordering)
}
/// Types that can produce a new `Self` from `&self` under a given
/// [`DeterministicConfig`].
///
/// NOTE(review): no implementations are visible in this file. The exact
/// contract (presumably "recompute `self` deterministically according to
/// `config`") should be confirmed against implementors.
pub trait DeterministicExecution {
    /// Returns a value of the same type, produced according to `config`.
    fn execute_deterministic(&self, config: &DeterministicConfig) -> Self;
}
/// Runs `f` and, when `config.verify_determinism` is set, runs it a second time
/// and asserts that both results are equal (debug builds only).
///
/// When verification is disabled, `f` is called exactly once.
///
/// # Panics
///
/// Panics if verification is enabled and the two runs produce different results.
#[cfg(debug_assertions)]
pub fn verify_deterministic<T, F>(config: &DeterministicConfig, mut f: F) -> T
where
    // `Clone` was previously required but never used; only equality and
    // debug printing are needed for the comparison.
    T: PartialEq + std::fmt::Debug,
    // `FnMut` accepts every `Fn` closure plus ones with internal scratch state.
    F: FnMut() -> T,
{
    if !config.verify_determinism {
        return f();
    }
    let first = f();
    let second = f();
    assert_eq!(
        first, second,
        "Non-deterministic behavior detected! Results differ between runs."
    );
    first
}
/// Release-build counterpart of the debug `verify_deterministic`.
///
/// Runs `f` exactly once and returns its result. `config` is ignored, so no
/// determinism check is performed in release builds.
#[cfg(not(debug_assertions))]
pub fn verify_deterministic<T, F>(_config: &DeterministicConfig, mut f: F) -> T
where
    // Mirrors the debug variant's bound so the same closures work in both profiles.
    F: FnMut() -> T,
{
    f()
}
/// Iterator adapter that pairs every item with its zero-based position,
/// similar to [`Iterator::enumerate`].
///
/// Created by [`StrictOrderIterator::new`] or `StrictOrderExt::strict_order`.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct StrictOrderIterator<I> {
    inner: I,
    /// Index that will be attached to the next yielded item.
    index: usize,
}

impl<I: Iterator> StrictOrderIterator<I> {
    /// Wraps `iter`; the first item it yields is tagged with index 0.
    pub fn new(iter: I) -> Self {
        Self {
            inner: iter,
            index: 0,
        }
    }

    /// Number of items yielded so far, which is also the index the next item will get.
    pub fn current_index(&self) -> usize {
        self.index
    }
}

impl<I: Iterator> Iterator for StrictOrderIterator<I> {
    type Item = (usize, I::Item);

    fn next(&mut self) -> Option<Self::Item> {
        let item = self.inner.next()?;
        let index = self.index;
        self.index += 1;
        // Sequential consumption already yields items in order. This SeqCst
        // fence is retained from the original implementation and only affects
        // memory ordering, not the produced sequence.
        std::sync::atomic::fence(Ordering::SeqCst);
        Some((index, item))
    }

    /// Tagging items with indices neither adds nor drops items, so the inner
    /// bounds carry over unchanged. Forwarding them lets `collect` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<I: ExactSizeIterator> ExactSizeIterator for StrictOrderIterator<I> {}

// `next` returns `None` whenever the inner iterator does, so fusedness carries over.
impl<I: std::iter::FusedIterator> std::iter::FusedIterator for StrictOrderIterator<I> {}
/// Adds [`strict_order`](StrictOrderExt::strict_order) to every iterator.
pub trait StrictOrderExt: Iterator + Sized {
    /// Wraps `self` in a [`StrictOrderIterator`], tagging each item with its
    /// zero-based position.
    fn strict_order(self) -> StrictOrderIterator<Self> {
        StrictOrderIterator::<Self>::new(self)
    }
}

impl<T> StrictOrderExt for T where T: Iterator {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that read or write the process-wide `DETERMINISTIC_MODE`
    /// flag, since the default test harness runs tests on parallel threads.
    static MODE_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `MODE_LOCK`, ignoring poisoning so one failing test does not
    /// cascade into every other global-state test.
    fn lock_mode() -> MutexGuard<'static, ()> {
        MODE_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_deterministic_config_default() {
        let config = DeterministicConfig::default();
        assert!(config.strict_order);
        assert!(config.deterministic_rng);
        assert!(config.no_gpu_nondeterminism);
        assert!(config.is_deterministic());
    }

    #[test]
    fn test_deterministic_config_relaxed() {
        let config = DeterministicConfig::relaxed();
        assert!(!config.strict_order);
        assert!(!config.deterministic_rng);
        assert!(!config.no_gpu_nondeterminism);
        assert!(!config.is_deterministic());
    }

    #[test]
    fn test_deterministic_guard() {
        let _lock = lock_mode();
        DETERMINISTIC_MODE.store(false, Ordering::SeqCst);
        assert!(!is_deterministic_mode());
        {
            let guard = DeterministicGuard::strict();
            assert!(guard.is_active());
            assert!(is_deterministic_mode());
        }
        assert!(!is_deterministic_mode());
    }

    #[test]
    fn test_nested_guards_restore_in_lifo_order() {
        let _lock = lock_mode();
        DETERMINISTIC_MODE.store(false, Ordering::SeqCst);
        {
            let _outer = DeterministicGuard::strict();
            assert!(is_deterministic_mode());
            {
                // The relaxed config turns the flag back off while the inner guard lives.
                let inner = DeterministicGuard::new(DeterministicConfig::relaxed());
                assert!(!inner.is_active());
            }
            // Dropping the inner guard restores the outer guard's setting.
            assert!(is_deterministic_mode());
        }
        assert!(!is_deterministic_mode());
    }

    #[test]
    fn test_strict_order_iterator() {
        let items = vec![10, 20, 30, 40, 50];
        let collected: Vec<_> = items.iter().strict_order().collect();
        assert_eq!(collected.len(), 5);
        for (i, (idx, &val)) in collected.iter().enumerate() {
            assert_eq!(*idx, i);
            assert_eq!(val, items[i]);
        }
    }

    #[test]
    #[cfg(debug_assertions)]
    fn test_verify_deterministic() {
        let config = DeterministicConfig {
            verify_determinism: true,
            ..Default::default()
        };
        let result = verify_deterministic(&config, || 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_verify_deterministic_passthrough() {
        // Valid in both build profiles: a constant closure is trivially deterministic.
        let config = DeterministicConfig::relaxed();
        assert_eq!(verify_deterministic(&config, || 7), 7);
    }
}