/// SM version number held as a single integer (e.g. `80`, `90`, `120`).
///
/// Ordering and equality are derived from the wrapped integer. The
/// [`Display`](std::fmt::Display) impl splits the value into
/// `value / 10` and `value % 10`, so `SmVersion(80)` renders as `SM 8.0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// Returns the PTX version string associated with this SM version.
    ///
    /// The mapping uses inclusive lower bounds:
    ///
    /// | SM value   | result  |
    /// |------------|---------|
    /// | `>= 100`   | `"8.7"` |
    /// | `90..=99`  | `"8.4"` |
    /// | `80..=89`  | `"8.0"` |
    /// | otherwise  | `"7.5"` |
    pub fn ptx_version_str(self) -> &'static str {
        match self.0 {
            100..=u32::MAX => "8.7",
            90..=99 => "8.4",
            80..=89 => "8.0",
            _ => "7.5",
        }
    }

    /// Returns the target name in the form `sm_<value>`, e.g. `sm_90`.
    pub fn target_str(self) -> String {
        let Self(code) = self;
        format!("sm_{code}")
    }
}

impl std::fmt::Display for SmVersion {
    /// Writes `SM <major>.<minor>` where `major = value / 10` and
    /// `minor = value % 10`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let major = self.0 / 10;
        let minor = self.0 % 10;
        write!(f, "SM {major}.{minor}")
    }
}
/// Small deterministic pseudo-random number generator built on a 64-bit
/// linear congruential step.
///
/// Two generators created with the same seed produce the same sequence.
/// This is a simple LCG and is not suitable for cryptographic use.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// Creates a generator from `seed`.
    ///
    /// The internal state starts at `seed + 1` (wrapping), so a seed of `0`
    /// does not start from an all-zero state.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the generator and returns the next 32-bit value.
    ///
    /// The state is updated with a wrapping multiply-add (the constants are
    /// the pair commonly attributed to Knuth's MMIX LCG), then the high bits
    /// are folded into the low bits before truncating to 32 bits.
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        // Truncation to the low 32 bits is intentional.
        ((self.state >> 33) ^ self.state) as u32
    }

    /// Returns the next value as an `f32` in the half-open range `[0, 1)`.
    ///
    /// The raw 32-bit draw is divided by 2^32. Because `f32` has only 24
    /// bits of mantissa, draws of `0xFFFF_FF80` and above round up to 2^32
    /// when converted, which would yield exactly `1.0`. Those rare cases are
    /// clamped to the largest `f32` below one; every other output is the
    /// plain quotient.
    #[inline]
    pub fn next_f32(&mut self) -> f32 {
        const TWO_POW_32: f32 = 4_294_967_296.0;
        // 1 - 2^-24: the largest f32 strictly less than 1.0.
        const LARGEST_BELOW_ONE: f32 = 1.0 - f32::EPSILON / 2.0;

        (self.next_u32() as f32 / TWO_POW_32).min(LARGEST_BELOW_ONE)
    }

    /// Returns a value in `0..n`, computed as the next 32-bit draw modulo `n`.
    ///
    /// The result is never larger than `u32::MAX`, and the modulo reduction
    /// carries a slight bias when `n` does not divide 2^32.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`, since the range `0..0` is empty.
    #[inline]
    pub fn next_usize(&mut self, n: usize) -> usize {
        assert!(n > 0, "LcgRng::next_usize requires n > 0");
        (self.next_u32() as usize) % n
    }

    /// Returns two standard-normal samples produced by the Box–Muller
    /// transform from two uniform draws.
    #[inline]
    pub fn next_normal_pair(&mut self) -> (f32, f32) {
        // The small offset keeps `u1` strictly positive so `ln` stays
        // finite. No upper clamp is needed: `next_f32` is strictly below
        // 1.0 and the offset is far smaller than the f32 spacing there.
        let u1 = self.next_f32() + 1e-10;
        let u2 = self.next_f32();
        let r = (-2.0_f32 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        (r * theta.cos(), r * theta.sin())
    }

    /// Shuffles `slice` in place with a Fisher–Yates pass driven by
    /// [`next_usize`](Self::next_usize).
    ///
    /// Slices of length 0 or 1 are left untouched and consume no random
    /// draws.
    pub fn shuffle<T>(&mut self, slice: &mut [T]) {
        for i in (1..slice.len()).rev() {
            let j = self.next_usize(i + 1);
            slice.swap(i, j);
        }
    }
}
/// Bundles the SM version, a device number, and a seeded [`LcgRng`] into one
/// value that can be passed around together.
#[derive(Debug, Clone)]
pub struct GnnHandle {
    sm: SmVersion,
    device: u32,
    rng: LcgRng,
}

impl GnnHandle {
    /// Builds a handle for `device` with the given `sm` version; the
    /// embedded generator is seeded with `seed`.
    #[must_use]
    pub fn new(device: u32, sm: SmVersion, seed: u64) -> Self {
        let rng = LcgRng::new(seed);
        GnnHandle { sm, device, rng }
    }

    /// Builds a handle with fixed settings: device `0`, `SmVersion(80)`,
    /// and RNG seed `42`.
    #[must_use]
    pub fn default_handle() -> Self {
        const DEVICE: u32 = 0;
        const SM: SmVersion = SmVersion(80);
        const SEED: u64 = 42;
        Self::new(DEVICE, SM, SEED)
    }

    /// Returns the SM version stored in this handle.
    #[must_use]
    #[inline]
    pub fn sm_version(&self) -> SmVersion {
        let GnnHandle { sm, .. } = self;
        *sm
    }

    /// Returns the device number stored in this handle.
    #[must_use]
    #[inline]
    pub fn device(&self) -> u32 {
        let GnnHandle { device, .. } = self;
        *device
    }

    /// Gives mutable access to the embedded generator so callers can draw
    /// values from it (which advances its state).
    #[inline]
    pub fn rng_mut(&mut self) -> &mut LcgRng {
        let GnnHandle { rng, .. } = self;
        rng
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// PTX string selection at and between the thresholds.
    #[test]
    fn sm_version_ptx_strings() {
        let cases = [
            (75, "7.5"),
            (80, "8.0"),
            (86, "8.0"),
            (90, "8.4"),
            (100, "8.7"),
            (120, "8.7"),
        ];
        for &(sm, expected) in &cases {
            assert_eq!(SmVersion(sm).ptx_version_str(), expected);
        }
    }

    /// Target names are `sm_` followed by the raw value.
    #[test]
    fn sm_version_target_str() {
        let cases = [(80, "sm_80"), (90, "sm_90"), (120, "sm_120")];
        for &(sm, expected) in &cases {
            assert_eq!(SmVersion(sm).target_str(), expected);
        }
    }

    /// `Display` splits the value into major and minor parts.
    #[test]
    fn sm_version_display() {
        let cases = [(80, "SM 8.0"), (90, "SM 9.0"), (120, "SM 12.0")];
        for &(sm, expected) in &cases {
            assert_eq!(SmVersion(sm).to_string(), expected);
        }
    }

    /// Derived comparisons follow the wrapped integer.
    #[test]
    fn sm_version_ordering() {
        let (sm75, sm80, sm90, sm100) =
            (SmVersion(75), SmVersion(80), SmVersion(90), SmVersion(100));
        assert!(sm75 < sm80);
        assert!(sm80 < sm90);
        assert!(sm100 > sm90);
        assert_eq!(sm80, SmVersion(80));
    }

    /// The default handle uses device 0 and SM 80.
    #[test]
    fn gnn_handle_default() {
        let handle = GnnHandle::default_handle();
        assert_eq!(handle.device(), 0);
        assert_eq!(handle.sm_version(), SmVersion(80));
    }

    /// Constructor arguments are reported back by the accessors.
    #[test]
    fn gnn_handle_custom() {
        let handle = GnnHandle::new(2, SmVersion(90), 123);
        assert_eq!(handle.device(), 2);
        assert_eq!(handle.sm_version(), SmVersion(90));
    }

    /// Equal seeds give equal `u32` streams.
    #[test]
    fn lcg_rng_deterministic() {
        let mut first = LcgRng::new(42);
        let mut second = LcgRng::new(42);
        for _ in 0..20 {
            let lhs = first.next_u32();
            let rhs = second.next_u32();
            assert_eq!(lhs, rhs);
        }
    }

    /// Sampled floats fall inside `[0, 1)`.
    #[test]
    fn lcg_rng_f32_range() {
        let mut rng = LcgRng::new(99);
        let unit_interval = 0.0..1.0;
        for _ in 0..1000 {
            let sample = rng.next_f32();
            assert!(unit_interval.contains(&sample));
        }
    }

    /// Sampled indices stay below the requested bound.
    #[test]
    fn lcg_rng_usize_range() {
        let mut rng = LcgRng::new(7);
        let bound = 10;
        for _ in 0..1000 {
            let sample = rng.next_usize(bound);
            assert!(sample < bound);
        }
    }

    /// Shuffling keeps exactly the same set of elements.
    #[test]
    fn lcg_rng_shuffle_permutation() {
        let original: Vec<usize> = (0..8).collect();
        let mut shuffled = original.clone();
        LcgRng::new(5).shuffle(&mut shuffled);

        let mut sorted = shuffled;
        sorted.sort_unstable();
        assert_eq!(sorted, original);
    }

    /// Both members of every normal pair are finite.
    #[test]
    fn lcg_rng_normal_pair_finite() {
        let mut rng = LcgRng::new(13);
        for _ in 0..1000 {
            let (first, second) = rng.next_normal_pair();
            assert!(first.is_finite());
            assert!(second.is_finite());
        }
    }

    /// Normal samples take both signs and stay within a sane magnitude.
    #[test]
    fn lcg_rng_normal_pair_spans_both_signs() {
        let mut rng = LcgRng::new(2024);
        let (mut saw_positive, mut saw_negative) = (false, false);
        for _ in 0..5000 {
            let (a, _) = rng.next_normal_pair();
            saw_positive |= a > 0.0;
            saw_negative |= a < 0.0;
            assert!(a.abs() < 12.0, "unreasonable magnitude: {a}");
        }
        assert!(saw_positive && saw_negative);
    }

    /// Equal seeds give equal normal-pair streams.
    #[test]
    fn lcg_rng_normal_pair_deterministic() {
        let mut first = LcgRng::new(321);
        let mut second = LcgRng::new(321);
        for _ in 0..50 {
            let lhs = first.next_normal_pair();
            let rhs = second.next_normal_pair();
            assert_eq!(lhs, rhs);
        }
    }

    /// Successive draws through the handle's generator differ.
    #[test]
    fn gnn_handle_rng_mut() {
        let mut handle = GnnHandle::default_handle();
        let first = handle.rng_mut().next_u32();
        let second = handle.rng_mut().next_u32();
        assert_ne!(first, second);
    }
}