use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
/// Memory-ordering semantics attached to a GPU memory operation.
///
/// `Weak` is the default ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryOrdering {
    #[default]
    Weak,
    Relaxed,
    Acquire,
    Release,
}

impl MemoryOrdering {
    /// Returns the PTX instruction modifier (leading dot included) that
    /// corresponds to this ordering.
    #[must_use]
    pub const fn to_ptx_modifier(self) -> &'static str {
        match self {
            MemoryOrdering::Weak => ".weak",
            MemoryOrdering::Relaxed => ".relaxed",
            MemoryOrdering::Acquire => ".acquire",
            MemoryOrdering::Release => ".release",
        }
    }

    /// `true` exactly when this ordering is [`MemoryOrdering::Acquire`].
    #[must_use]
    pub const fn is_acquire(self) -> bool {
        match self {
            MemoryOrdering::Acquire => true,
            _ => false,
        }
    }

    /// `true` exactly when this ordering is [`MemoryOrdering::Release`].
    #[must_use]
    pub const fn is_release(self) -> bool {
        match self {
            MemoryOrdering::Release => true,
            _ => false,
        }
    }
}
/// Scope at which a memory operation's effects are made visible.
///
/// `Device` is the default scope.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryScope {
    Thread,
    Block,
    Cluster,
    #[default]
    Device,
    System,
}

impl MemoryScope {
    /// Returns the PTX scope qualifier (leading dot included) for this scope.
    ///
    /// Note: `Thread` and `Block` both lower to `.cta` — this mapping
    /// collapses them to the same qualifier.
    #[must_use]
    pub const fn to_ptx_scope(self) -> &'static str {
        match self {
            MemoryScope::Thread => ".cta",
            MemoryScope::Block => ".cta",
            MemoryScope::Cluster => ".cluster",
            MemoryScope::Device => ".gpu",
            MemoryScope::System => ".sys",
        }
    }
}
/// Monotonically increasing source of unique token ids. Starts at 1 so
/// `Token::new` can never produce id 0.
static NEXT_TOKEN_ID: AtomicU64 = AtomicU64::new(1);

/// An opaque, process-unique identifier used to track ordering between
/// memory operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Token {
    id: u64,
}

impl Token {
    /// Creates a token with a fresh, process-unique id.
    #[must_use]
    pub fn new() -> Self {
        // Relaxed suffices: only uniqueness of the counter value matters,
        // not ordering with respect to other memory operations.
        Self {
            id: NEXT_TOKEN_ID.fetch_add(1, Ordering::Relaxed),
        }
    }

    /// Returns this token's raw id.
    #[must_use]
    pub const fn id(self) -> u64 {
        self.id
    }

    /// Reconstructs a token from a raw id (e.g. one previously obtained
    /// from [`Token::id`]). Does not consult the global counter.
    #[must_use]
    pub const fn from_id(id: u64) -> Self {
        Self { id }
    }
}

impl Default for Token {
    fn default() -> Self {
        Self::new()
    }
}

/// Returns a fresh token representing the join of `tokens`.
///
/// The inputs are intentionally unused here: the join's dependency edges
/// are recorded separately via [`TokenGraph::join`].
#[must_use]
pub fn join_tokens(_tokens: &[Token]) -> Token {
    Token::new()
}

/// Dependency graph between tokens, used to validate ordering constraints.
#[derive(Debug, Clone, Default)]
pub struct TokenGraph {
    // Ids of every token known to the graph; DFS roots for `has_cycle`.
    tokens: HashSet<u64>,
    // Adjacency list: (dependent id, ids it depends on). Invariant: at most
    // one entry per dependent id, maintained by `merge_dependencies`.
    dependencies: Vec<(u64, Vec<u64>)>,
}

impl TokenGraph {
    /// Creates an empty graph.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `token` with the graph.
    pub fn create_token(&mut self, token: Token) {
        self.tokens.insert(token.id());
    }

    /// Records that `dependent` depends on `dependency`.
    ///
    /// Both ids are also registered with the graph. Previously they were
    /// not, so `has_cycle` (which starts its DFS only from registered
    /// tokens) missed cycles among tokens never passed to `create_token`.
    pub fn add_dependency(&mut self, dependent: Token, dependency: Token) {
        self.tokens.insert(dependent.id());
        self.tokens.insert(dependency.id());
        self.merge_dependencies(dependent.id(), vec![dependency.id()]);
    }

    /// Records that `result` joins (depends on) all of `sources`, and
    /// registers every involved id with the graph.
    ///
    /// Edges are merged into any existing entry for `result`. Previously a
    /// second `(id, deps)` entry was pushed unconditionally, and because
    /// lookups stop at the first matching entry, those join edges were
    /// silently invisible to `get_dependencies` and cycle detection.
    pub fn join(&mut self, result: Token, sources: &[Token]) {
        self.tokens.insert(result.id());
        for source in sources {
            self.tokens.insert(source.id());
        }
        let deps: Vec<u64> = sources.iter().map(|t| t.id()).collect();
        self.merge_dependencies(result.id(), deps);
    }

    // Appends `new_deps` to the single dependency entry for `dependent`,
    // creating the entry if none exists (preserving the one-entry-per-id
    // invariant that `find`-based lookups rely on).
    fn merge_dependencies(&mut self, dependent: u64, mut new_deps: Vec<u64>) {
        if let Some((_, deps)) = self
            .dependencies
            .iter_mut()
            .find(|(d, _)| *d == dependent)
        {
            deps.append(&mut new_deps);
        } else {
            self.dependencies.push((dependent, new_deps));
        }
    }

    /// Returns `true` if `token` has at least one recorded dependency entry.
    #[must_use]
    pub fn has_dependencies(&self, token: Token) -> bool {
        self.dependencies.iter().any(|(d, _)| *d == token.id())
    }

    /// Returns the dependency ids recorded for `token` (empty if none).
    #[must_use]
    pub fn get_dependencies(&self, token: Token) -> Vec<u64> {
        self.dependencies
            .iter()
            .find(|(d, _)| *d == token.id())
            .map(|(_, deps)| deps.clone())
            .unwrap_or_default()
    }

    /// Returns `true` if the dependency graph contains a cycle.
    #[must_use]
    pub fn has_cycle(&self) -> bool {
        let mut visited = HashSet::new();
        let mut rec_stack = HashSet::new();
        for &token_id in &self.tokens {
            if self.has_cycle_dfs(token_id, &mut visited, &mut rec_stack) {
                return true;
            }
        }
        false
    }

    // Depth-first search; `rec_stack` holds the ids on the current path,
    // so revisiting one of them means a back edge, i.e. a cycle.
    fn has_cycle_dfs(
        &self,
        token_id: u64,
        visited: &mut HashSet<u64>,
        rec_stack: &mut HashSet<u64>,
    ) -> bool {
        if rec_stack.contains(&token_id) {
            return true;
        }
        if visited.contains(&token_id) {
            return false;
        }
        visited.insert(token_id);
        rec_stack.insert(token_id);
        if let Some((_, deps)) = self.dependencies.iter().find(|(d, _)| *d == token_id) {
            for &dep in deps {
                if self.has_cycle_dfs(dep, visited, rec_stack) {
                    return true;
                }
            }
        }
        rec_stack.remove(&token_id);
        false
    }

    /// Number of tokens known to the graph.
    #[must_use]
    pub fn token_count(&self) -> usize {
        self.tokens.len()
    }
}
pub struct TkoAnalysis {
pub graph: TokenGraph,
pub eliminable_barriers: Vec<usize>,
}
impl TkoAnalysis {
#[must_use]
pub fn new() -> Self {
Self {
graph: TokenGraph::new(),
eliminable_barriers: Vec::new(),
}
}
#[must_use]
pub fn is_sound(&self) -> bool {
!self.graph.has_cycle()
}
#[must_use]
pub fn eliminable_count(&self) -> usize {
self.eliminable_barriers.len()
}
}
impl Default for TkoAnalysis {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Two freshly created tokens must receive distinct ids.
    #[test]
    fn test_token_creation() {
        let t1 = Token::new();
        let t2 = Token::new();
        assert_ne!(t1.id(), t2.id());
    }

    // join_tokens returns a brand-new token, never one of its inputs.
    #[test]
    fn test_join_tokens() {
        let t1 = Token::new();
        let t2 = Token::new();
        let t3 = Token::new();
        let joined = join_tokens(&[t1, t2, t3]);
        assert_ne!(joined.id(), t1.id());
        assert_ne!(joined.id(), t2.id());
        assert_ne!(joined.id(), t3.id());
    }

    // Exact PTX modifier strings for Weak/Relaxed/Acquire.
    #[test]
    fn test_memory_ordering_relaxed_fastest() {
        let weak = MemoryOrdering::Weak;
        let relaxed = MemoryOrdering::Relaxed;
        let acquire = MemoryOrdering::Acquire;
        assert_eq!(weak.to_ptx_modifier(), ".weak");
        assert_eq!(relaxed.to_ptx_modifier(), ".relaxed");
        assert_eq!(acquire.to_ptx_modifier(), ".acquire");
    }

    // A three-token dependency loop (t2->t1, t3->t2, t1->t3) is a cycle.
    // All tokens are registered via create_token before adding edges.
    #[test]
    fn test_cycle_detection() {
        let mut graph = TokenGraph::new();
        let t1 = Token::new();
        let t2 = Token::new();
        let t3 = Token::new();
        graph.create_token(t1);
        graph.create_token(t2);
        graph.create_token(t3);
        graph.add_dependency(t2, t1);
        graph.add_dependency(t3, t2);
        graph.add_dependency(t1, t3);
        assert!(graph.has_cycle(), "Should detect cycle");
    }

    // A linear chain (t1 <- t2 <- t3) must not be reported as cyclic.
    #[test]
    fn test_no_cycle() {
        let mut graph = TokenGraph::new();
        let t1 = Token::new();
        let t2 = Token::new();
        let t3 = Token::new();
        graph.create_token(t1);
        graph.create_token(t2);
        graph.create_token(t3);
        graph.add_dependency(t2, t1);
        graph.add_dependency(t3, t2);
        assert!(!graph.has_cycle(), "Should not detect cycle");
    }

    // A fresh (empty-graph) analysis is trivially sound.
    #[test]
    fn test_tko_analysis_sound() {
        let analysis = TkoAnalysis::new();
        assert!(analysis.is_sound(), "Empty analysis should be sound");
    }

    // Exact PTX scope strings; Thread and Block both map to ".cta".
    #[test]
    fn test_memory_scope_ptx() {
        assert_eq!(MemoryScope::Thread.to_ptx_scope(), ".cta");
        assert_eq!(MemoryScope::Block.to_ptx_scope(), ".cta");
        assert_eq!(MemoryScope::Cluster.to_ptx_scope(), ".cluster");
        assert_eq!(MemoryScope::Device.to_ptx_scope(), ".gpu");
        assert_eq!(MemoryScope::System.to_ptx_scope(), ".sys");
    }

    // TokenGraph::join records one dependency per source token.
    #[test]
    fn test_token_graph_join() {
        let mut graph = TokenGraph::new();
        let t1 = Token::new();
        let t2 = Token::new();
        let result = Token::new();
        graph.create_token(t1);
        graph.create_token(t2);
        graph.join(result, &[t1, t2]);
        assert!(graph.has_dependencies(result));
        assert_eq!(graph.get_dependencies(result).len(), 2);
    }

    // Joining an empty slice still yields a valid (nonzero-id) token,
    // since the id counter starts at 1.
    #[test]
    fn test_empty_join() {
        let joined = join_tokens(&[]);
        assert!(joined.id() > 0);
    }

    // Even a single-token join produces a new token, not the input.
    #[test]
    fn test_single_token_join() {
        let t1 = Token::new();
        let joined = join_tokens(&[t1]);
        assert_ne!(joined.id(), t1.id());
    }

    // from_id round-trips the raw id without touching the counter.
    #[test]
    fn test_token_from_id() {
        let t = Token::from_id(42);
        assert_eq!(t.id(), 42);
    }

    // is_acquire/is_release are mutually exclusive predicates.
    #[test]
    fn test_memory_ordering_acquire_release() {
        let acquire = MemoryOrdering::Acquire;
        let release = MemoryOrdering::Release;
        assert!(acquire.is_acquire());
        assert!(!acquire.is_release());
        assert!(release.is_release());
        assert!(!release.is_acquire());
    }
}
// Property-based tests (proptest): randomized inputs over the ranges below.
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Any batch of 1..100 new tokens has pairwise-distinct ids
        // (set of ids has the same cardinality as the batch).
        #[test]
        fn token_ids_unique(count in 1usize..100) {
            let tokens: Vec<Token> = (0..count).map(|_| Token::new()).collect();
            let ids: std::collections::HashSet<u64> = tokens.iter().map(|t| t.id()).collect();
            prop_assert_eq!(ids.len(), tokens.len());
        }

        // The joined token's id differs from every input id, for any
        // input size including zero.
        #[test]
        fn join_produces_unique_token(count in 0usize..20) {
            let tokens: Vec<Token> = (0..count).map(|_| Token::new()).collect();
            let joined = join_tokens(&tokens);
            for t in &tokens {
                prop_assert_ne!(joined.id(), t.id());
            }
        }

        // A linear dependency chain of any length 2..20 is acyclic.
        #[test]
        fn linear_graph_has_no_cycle(count in 2usize..20) {
            let mut graph = TokenGraph::new();
            let tokens: Vec<Token> = (0..count).map(|_| Token::new()).collect();
            for t in &tokens {
                graph.create_token(*t);
            }
            for i in 1..tokens.len() {
                graph.add_dependency(tokens[i], tokens[i - 1]);
            }
            prop_assert!(!graph.has_cycle());
        }

        // Every ordering's PTX modifier is nonempty and dot-prefixed.
        // (_dummy only forces proptest to run the case; the inputs are fixed.)
        #[test]
        fn memory_ordering_ptx_modifiers_nonempty(_dummy in 0u8..4) {
            let orderings = [
                MemoryOrdering::Weak,
                MemoryOrdering::Relaxed,
                MemoryOrdering::Acquire,
                MemoryOrdering::Release,
            ];
            for ordering in orderings {
                let modifier = ordering.to_ptx_modifier();
                prop_assert!(!modifier.is_empty());
                prop_assert!(modifier.starts_with('.'));
            }
        }

        // Every scope's PTX qualifier is nonempty and dot-prefixed.
        #[test]
        fn memory_scope_ptx_scopes_nonempty(_dummy in 0u8..5) {
            let scopes = [
                MemoryScope::Thread,
                MemoryScope::Block,
                MemoryScope::Cluster,
                MemoryScope::Device,
                MemoryScope::System,
            ];
            for scope in scopes {
                let ptx_scope = scope.to_ptx_scope();
                prop_assert!(!ptx_scope.is_empty());
                prop_assert!(ptx_scope.starts_with('.'));
            }
        }

        // from_id preserves any nonzero raw id exactly.
        #[test]
        fn token_from_id_preserves(id in 1u64..u64::MAX) {
            let t = Token::from_id(id);
            prop_assert_eq!(t.id(), id);
        }
    }
}