1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#![feature(platform_intrinsics)]
#![no_std]

/// A three-dimensional extent, as returned by `block_dim()` and `grid_dim()`.
///
/// The fields are public so that callers outside this crate can read the
/// individual axes (needed for 2-D / 3-D kernels). The type is plain data
/// (three `i32`s), so it is `Copy`. This means passing it by value, e.g. to
/// `Idx3::into_id`, does not consume it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Dim3 {
    /// Extent along the x axis.
    pub x: i32,
    /// Extent along the y axis.
    pub y: i32,
    /// Extent along the z axis.
    pub z: i32,
}

/// A three-dimensional index, as returned by `block_idx()` and `thread_idx()`.
///
/// The fields are public so that callers outside this crate can read the
/// individual axes. The type is plain data (three `i32`s), so it is `Copy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Idx3 {
    /// Index along the x axis.
    pub x: i32,
    /// Index along the y axis.
    pub y: i32,
    /// Index along the z axis.
    pub z: i32,
}

/// Returns the extent of the current thread block along each axis.
///
/// The components come from the `nvptx_block_dim_{x,y,z}` intrinsics
/// (presumably CUDA's `blockDim`).
#[inline(always)]
pub fn block_dim() -> Dim3 {
    // SAFETY: these intrinsics take no arguments. Calling them is `unsafe`
    // only because they are declared in an `extern` block. They presumably
    // just read the launch configuration of the calling thread.
    // NOTE(review): only meaningful when compiled for an nvptx target.
    let (x, y, z) = unsafe {
        (
            nvptx_block_dim_x(),
            nvptx_block_dim_y(),
            nvptx_block_dim_z(),
        )
    };
    Dim3 { x, y, z }
}

/// Returns the index of the current block within the grid.
///
/// The components come from the `nvptx_block_idx_{x,y,z}` intrinsics
/// (presumably CUDA's `blockIdx`).
#[inline(always)]
pub fn block_idx() -> Idx3 {
    // SAFETY (all three calls): argument-less intrinsics that are `unsafe`
    // only because they are declared in an `extern` block.
    // NOTE(review): only meaningful when compiled for an nvptx target.
    let x = unsafe { nvptx_block_idx_x() };
    let y = unsafe { nvptx_block_idx_y() };
    let z = unsafe { nvptx_block_idx_z() };
    Idx3 { x, y, z }
}

/// Returns the extent of the launch grid, measured in blocks, along each axis.
///
/// The components come from the `nvptx_grid_dim_{x,y,z}` intrinsics
/// (presumably CUDA's `gridDim`).
#[inline(always)]
pub fn grid_dim() -> Dim3 {
    // SAFETY: these intrinsics take no arguments. Calling them is `unsafe`
    // only because they are declared in an `extern` block.
    // NOTE(review): only meaningful when compiled for an nvptx target.
    let (x, y, z) = unsafe {
        (
            nvptx_grid_dim_x(),
            nvptx_grid_dim_y(),
            nvptx_grid_dim_z(),
        )
    };
    Dim3 { x, y, z }
}

/// Returns the index of the calling thread within its block.
///
/// The components come from the `nvptx_thread_idx_{x,y,z}` intrinsics
/// (presumably CUDA's `threadIdx`).
#[inline(always)]
pub fn thread_idx() -> Idx3 {
    // SAFETY (all three calls): argument-less intrinsics that are `unsafe`
    // only because they are declared in an `extern` block.
    // NOTE(review): only meaningful when compiled for an nvptx target.
    let x = unsafe { nvptx_thread_idx_x() };
    let y = unsafe { nvptx_thread_idx_y() };
    let z = unsafe { nvptx_thread_idx_z() };
    Idx3 { x, y, z }
}

impl Dim3 {
    #[inline(always)]
    pub fn size(&self) -> i32 {
        (self.x * self.y * self.z)
    }
}

impl Idx3 {
    #[inline(always)]
    pub fn into_id(&self, dim: Dim3) -> i32 {
        self.x + self.y * dim.x + self.z * dim.x * dim.y
    }
}

// Raw NVPTX platform intrinsics. Each one is `unsafe` to call because it is
// declared in an `extern` block. The safe wrappers above (`block_dim`,
// `block_idx`, `grid_dim`, `thread_idx`) are the intended entry points.
extern "platform-intrinsic" {
    /// x extent of the current block (presumably CUDA `blockDim.x`).
    pub fn nvptx_block_dim_x() -> i32;
    /// y extent of the current block (presumably CUDA `blockDim.y`).
    pub fn nvptx_block_dim_y() -> i32;
    /// z extent of the current block (presumably CUDA `blockDim.z`).
    pub fn nvptx_block_dim_z() -> i32;
    /// x index of the current block (presumably CUDA `blockIdx.x`).
    pub fn nvptx_block_idx_x() -> i32;
    /// y index of the current block (presumably CUDA `blockIdx.y`).
    pub fn nvptx_block_idx_y() -> i32;
    /// z index of the current block (presumably CUDA `blockIdx.z`).
    pub fn nvptx_block_idx_z() -> i32;
    /// x extent of the grid in blocks (presumably CUDA `gridDim.x`).
    pub fn nvptx_grid_dim_x() -> i32;
    /// y extent of the grid in blocks (presumably CUDA `gridDim.y`).
    pub fn nvptx_grid_dim_y() -> i32;
    /// z extent of the grid in blocks (presumably CUDA `gridDim.z`).
    pub fn nvptx_grid_dim_z() -> i32;
    /// Block-wide barrier (presumably CUDA `__syncthreads`).
    ///
    /// NOTE(review): in CUDA, reaching this barrier from divergent control
    /// flow is undefined behavior. Callers must ensure all threads of the
    /// block execute it.
    pub fn nvptx_syncthreads();
    /// x index of the calling thread within its block (presumably `threadIdx.x`).
    pub fn nvptx_thread_idx_x() -> i32;
    /// y index of the calling thread within its block (presumably `threadIdx.y`).
    pub fn nvptx_thread_idx_y() -> i32;
    /// z index of the calling thread within its block (presumably `threadIdx.z`).
    pub fn nvptx_thread_idx_z() -> i32;
}

/// Returns a linear index for the calling thread that is unique across the
/// whole launch.
///
/// The block's linear id within the grid (from `block_idx()` flattened over
/// `grid_dim()`) is scaled by the number of threads per block. The thread's
/// linear id within its block (from `thread_idx()` flattened over
/// `block_dim()`) is then added.
#[inline(always)]
pub fn index() -> isize {
    // Widen to `isize` before combining, so that
    // `block_id * threads_per_block` cannot overflow `i32`.
    let threads_per_block = block_dim().size() as isize;
    let block_id = block_idx().into_id(grid_dim()) as isize;
    let thread_id = thread_idx().into_id(block_dim()) as isize;
    // Scaling by threads-per-block gives each block a disjoint index range.
    // Simply adding `block_id + thread_id` would give block 1/thread 0 and
    // block 0/thread 1 the same index.
    block_id * threads_per_block + thread_id
}