use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
/// Memory ordering semantics attached to a memory operation.
///
/// Each variant corresponds to one PTX `.sem` qualifier spelling, produced
/// by [`MemoryOrdering::to_ptx_modifier`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryOrdering {
    /// No ordering guarantees; the default for plain memory operations.
    #[default]
    Weak,
    /// Atomic but unordered with respect to other memory operations.
    Relaxed,
    /// Acquire ordering (load side of a synchronizes-with edge).
    Acquire,
    /// Release ordering (store side of a synchronizes-with edge).
    Release,
}

impl MemoryOrdering {
    /// Returns the PTX qualifier string for this ordering, e.g. `".acquire"`.
    #[must_use]
    pub const fn to_ptx_modifier(self) -> &'static str {
        match self {
            MemoryOrdering::Relaxed => ".relaxed",
            MemoryOrdering::Acquire => ".acquire",
            MemoryOrdering::Release => ".release",
            MemoryOrdering::Weak => ".weak",
        }
    }

    /// True exactly for [`MemoryOrdering::Acquire`].
    #[must_use]
    pub const fn is_acquire(self) -> bool {
        matches!(self, MemoryOrdering::Acquire)
    }

    /// True exactly for [`MemoryOrdering::Release`].
    #[must_use]
    pub const fn is_release(self) -> bool {
        matches!(self, MemoryOrdering::Release)
    }
}
/// Scope at which a memory operation's ordering is made visible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryScope {
    /// Single thread.
    Thread,
    /// Thread block (CTA).
    Block,
    /// Thread-block cluster.
    Cluster,
    /// Whole device; the default scope.
    #[default]
    Device,
    /// System-wide scope (device plus host and peer devices).
    System,
}

impl MemoryScope {
    /// Returns the PTX scope qualifier for this scope, e.g. `".gpu"`.
    ///
    /// `Thread` and `Block` both map to the CTA qualifier: PTX offers no
    /// narrower scope qualifier than `.cta`.
    #[must_use]
    pub const fn to_ptx_scope(self) -> &'static str {
        match self {
            MemoryScope::System => ".sys",
            MemoryScope::Device => ".gpu",
            MemoryScope::Cluster => ".cluster",
            MemoryScope::Thread | MemoryScope::Block => ".cta",
        }
    }
}
/// Monotonically increasing source of fresh token ids; ids are never reused.
static NEXT_TOKEN_ID: AtomicU64 = AtomicU64::new(1);

/// An opaque ordering token identifying a node in the dependency graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Token {
    id: u64,
}

impl Token {
    /// Creates a token with a process-unique id.
    #[must_use]
    pub fn new() -> Self {
        // Relaxed suffices: only uniqueness of the counter value matters,
        // not ordering relative to other memory operations.
        Self { id: NEXT_TOKEN_ID.fetch_add(1, Ordering::Relaxed) }
    }

    /// Returns the token's numeric id.
    #[must_use]
    pub const fn id(self) -> u64 {
        self.id
    }

    /// Reconstructs a token from a raw id (e.g. one previously stored).
    ///
    /// NOTE(review): this does not advance `NEXT_TOKEN_ID`, so a token built
    /// from a large id may later collide with one minted by `Token::new` —
    /// presumably callers only round-trip ids obtained from `Token::id`.
    #[must_use]
    pub const fn from_id(id: u64) -> Self {
        Self { id }
    }
}

impl Default for Token {
    fn default() -> Self {
        Self::new()
    }
}

/// Produces a fresh token representing the join of `tokens`.
///
/// Tokens carry no payload, so the join itself is just a new id; the actual
/// dependency edges are recorded separately via [`TokenGraph::join`].
#[must_use]
pub fn join_tokens(_tokens: &[Token]) -> Token {
    Token::new()
}

/// Dependency graph over token ids.
///
/// `tokens` holds ids explicitly registered via [`TokenGraph::create_token`]
/// or [`TokenGraph::join`]; `dependencies` maps a dependent id to the ids it
/// depends on. Each id appears at most once as a key in `dependencies`.
#[derive(Debug, Clone, Default)]
pub struct TokenGraph {
    tokens: HashSet<u64>,
    dependencies: Vec<(u64, Vec<u64>)>,
}

impl TokenGraph {
    /// Creates an empty graph.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `token` as a node of the graph.
    pub fn create_token(&mut self, token: Token) {
        self.tokens.insert(token.id());
    }

    /// Records that `dependent` must be ordered after `dependency`.
    pub fn add_dependency(&mut self, dependent: Token, dependency: Token) {
        self.entry_mut(dependent.id()).push(dependency.id());
    }

    /// Records that `result` joins all of `sources` (depends on each of them)
    /// and registers `result` as a node.
    ///
    /// Merges into any existing entry for `result` instead of pushing a
    /// duplicate key, so lookups that stop at the first matching key (e.g.
    /// `get_dependencies`) see every dependency.
    pub fn join(&mut self, result: Token, sources: &[Token]) {
        self.entry_mut(result.id()).extend(sources.iter().map(|t| t.id()));
        self.tokens.insert(result.id());
    }

    /// Returns the dependency list for `id`, creating an empty entry if absent.
    /// Guarantees each id occurs at most once as a key in `dependencies`.
    fn entry_mut(&mut self, id: u64) -> &mut Vec<u64> {
        if let Some(pos) = self.dependencies.iter().position(|(d, _)| *d == id) {
            &mut self.dependencies[pos].1
        } else {
            self.dependencies.push((id, Vec::new()));
            &mut self.dependencies.last_mut().expect("entry just pushed").1
        }
    }

    /// True if `token` has at least one recorded dependency entry.
    #[must_use]
    pub fn has_dependencies(&self, token: Token) -> bool {
        self.dependencies.iter().any(|(d, _)| *d == token.id())
    }

    /// Returns the ids `token` depends on (empty if none recorded).
    #[must_use]
    pub fn get_dependencies(&self, token: Token) -> Vec<u64> {
        self.dependencies
            .iter()
            .find(|(d, _)| *d == token.id())
            .map(|(_, deps)| deps.clone())
            .unwrap_or_default()
    }

    /// True if the dependency relation contains a cycle.
    #[must_use]
    pub fn has_cycle(&self) -> bool {
        let mut visited = HashSet::new();
        let mut rec_stack = HashSet::new();
        // Start the DFS from every registered token AND every dependency key:
        // edges added via `add_dependency` reference ids that were never
        // registered in `tokens`, and a cycle among only those ids must still
        // be detected.
        let roots = self
            .tokens
            .iter()
            .copied()
            .chain(self.dependencies.iter().map(|(d, _)| *d));
        for token_id in roots {
            if self.has_cycle_dfs(token_id, &mut visited, &mut rec_stack) {
                return true;
            }
        }
        false
    }

    /// Depth-first cycle check: a node found on the current recursion stack
    /// closes a cycle; `visited` prunes re-exploration of finished subtrees.
    fn has_cycle_dfs(
        &self,
        token_id: u64,
        visited: &mut HashSet<u64>,
        rec_stack: &mut HashSet<u64>,
    ) -> bool {
        if rec_stack.contains(&token_id) {
            return true;
        }
        if visited.contains(&token_id) {
            return false;
        }
        visited.insert(token_id);
        rec_stack.insert(token_id);
        if let Some((_, deps)) = self.dependencies.iter().find(|(d, _)| *d == token_id) {
            for &dep in deps {
                if self.has_cycle_dfs(dep, visited, rec_stack) {
                    return true;
                }
            }
        }
        rec_stack.remove(&token_id);
        false
    }

    /// Number of explicitly registered tokens.
    #[must_use]
    pub fn token_count(&self) -> usize {
        self.tokens.len()
    }
}
pub struct TkoAnalysis {
pub graph: TokenGraph,
pub eliminable_barriers: Vec<usize>,
}
impl TkoAnalysis {
#[must_use]
pub fn new() -> Self {
Self { graph: TokenGraph::new(), eliminable_barriers: Vec::new() }
}
#[must_use]
pub fn is_sound(&self) -> bool {
!self.graph.has_cycle()
}
#[must_use]
pub fn eliminable_count(&self) -> usize {
self.eliminable_barriers.len()
}
}
impl Default for TkoAnalysis {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests;
#[cfg(test)]
mod property_tests;