/// Kinds of compute workloads a `WorkloadSpec` can describe.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
    /// Pointwise op applied independently to each element.
    Elementwise,
    /// Vector dot product.
    Dot,
    /// Dense matrix multiplication.
    Matmul,
    /// 2-D convolution.
    Conv2d,
    /// Multi-head attention (sized via batch/seq_len/num_heads/head_dim).
    Attention,
    Softmax,
    LayerNorm,
    /// Feed-forward network block.
    Ffn,
    /// Reduction over a tensor (exact reduction kind not encoded here).
    Reduce,
    /// User-defined operation identified by an opaque tag.
    Custom(u32),
}
27
/// Element types supported by tensors in a workload.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    F32,
    F16,
    Bf16,
    I8,
    U8,
    Q4,
}

impl DataType {
    /// Bytes of storage accounted per element of this type.
    ///
    /// NOTE(review): `Q4` is nominally a 4-bit type; charging a full byte
    /// per element here is presumably a deliberate overestimate — confirm
    /// against the packing scheme actually used.
    pub fn byte_size(&self) -> usize {
        match self {
            Self::F32 => 4,
            Self::F16 | Self::Bf16 => 2,
            Self::I8 | Self::U8 | Self::Q4 => 1,
        }
    }
}
56
/// Problem-size parameters for a workload.
///
/// Only the fields relevant to a given `Operation` are meaningful for
/// that operation; the constructors leave the rest at their defaults.
#[derive(Debug, Clone, PartialEq)]
pub struct Dimensions {
    pub batch: usize,
    pub seq_len: usize,
    pub num_heads: usize,
    pub head_dim: usize,
    pub hidden_dim: usize,
    pub m: usize,
    pub n: usize,
    pub k: usize,
}

impl Default for Dimensions {
    /// Every size defaults to 1, except `head_dim`, which defaults to 64.
    fn default() -> Self {
        Self {
            batch: 1,
            seq_len: 1,
            num_heads: 1,
            head_dim: 64,
            hidden_dim: 1,
            m: 1,
            n: 1,
            k: 1,
        }
    }
}

impl Dimensions {
    /// Sizes for a 1-D workload of `size` elements (carried in `n`).
    pub fn vector(size: usize) -> Self {
        Self { n: size, ..Self::default() }
    }

    /// Sizes for an (m x k) by (k x n) matrix multiplication.
    pub fn matmul(m: usize, n: usize, k: usize) -> Self {
        Self { m, n, k, ..Self::default() }
    }

    /// Sizes for a multi-head attention workload.
    pub fn attention(batch: usize, seq_len: usize, num_heads: usize, head_dim: usize) -> Self {
        Self { batch, seq_len, num_heads, head_dim, ..Self::default() }
    }
}
123
124#[derive(Debug, Clone, PartialEq)]
126pub struct TensorSpec {
127 pub name: String,
129 pub shape: Vec<usize>,
131 pub dtype: DataType,
133 pub stride: Option<Vec<usize>>,
135}
136
137impl TensorSpec {
138 pub fn new(name: impl Into<String>, shape: Vec<usize>, dtype: DataType) -> Self {
140 Self {
141 name: name.into(),
142 shape,
143 dtype,
144 stride: None,
145 }
146 }
147
148 pub fn numel(&self) -> usize {
150 self.shape.iter().product()
151 }
152
153 pub fn byte_size(&self) -> usize {
155 self.numel() * self.dtype.byte_size()
156 }
157}
158
159#[derive(Debug, Clone, PartialEq)]
161pub struct WorkloadSpec {
162 pub operation: Operation,
164 pub dimensions: Dimensions,
166 pub dtype: DataType,
168 pub inputs: Vec<TensorSpec>,
170 pub outputs: Vec<TensorSpec>,
172}
173
174impl WorkloadSpec {
175 pub fn dot(size: usize) -> Self {
177 Self {
178 operation: Operation::Dot,
179 dimensions: Dimensions::vector(size),
180 dtype: DataType::F32,
181 inputs: vec![
182 TensorSpec::new("a", vec![size], DataType::F32),
183 TensorSpec::new("b", vec![size], DataType::F32),
184 ],
185 outputs: vec![TensorSpec::new("result", vec![1], DataType::F32)],
186 }
187 }
188
189 pub fn matmul(m: usize, n: usize, k: usize) -> Self {
191 Self {
192 operation: Operation::Matmul,
193 dimensions: Dimensions::matmul(m, n, k),
194 dtype: DataType::F32,
195 inputs: vec![
196 TensorSpec::new("a", vec![m, k], DataType::F32),
197 TensorSpec::new("b", vec![k, n], DataType::F32),
198 ],
199 outputs: vec![TensorSpec::new("c", vec![m, n], DataType::F32)],
200 }
201 }
202
203 pub fn attention(batch: usize, seq_len: usize, num_heads: usize, head_dim: usize) -> Self {
205 let embed_dim = num_heads * head_dim;
206 Self {
207 operation: Operation::Attention,
208 dimensions: Dimensions::attention(batch, seq_len, num_heads, head_dim),
209 dtype: DataType::F32,
210 inputs: vec![
211 TensorSpec::new("q", vec![batch, seq_len, embed_dim], DataType::F32),
212 TensorSpec::new("k", vec![batch, seq_len, embed_dim], DataType::F32),
213 TensorSpec::new("v", vec![batch, seq_len, embed_dim], DataType::F32),
214 ],
215 outputs: vec![TensorSpec::new(
216 "out",
217 vec![batch, seq_len, embed_dim],
218 DataType::F32,
219 )],
220 }
221 }
222
223 pub fn elementwise(size: usize) -> Self {
225 Self {
226 operation: Operation::Elementwise,
227 dimensions: Dimensions::vector(size),
228 dtype: DataType::F32,
229 inputs: vec![TensorSpec::new("input", vec![size], DataType::F32)],
230 outputs: vec![TensorSpec::new("output", vec![size], DataType::F32)],
231 }
232 }
233
234 pub fn flop_count(&self) -> usize {
236 match self.operation {
237 Operation::Dot => self.dimensions.n * 2, Operation::Matmul => self.dimensions.m * self.dimensions.n * self.dimensions.k * 2,
239 Operation::Attention => {
240 let b = self.dimensions.batch;
241 let s = self.dimensions.seq_len;
242 let h = self.dimensions.num_heads;
243 let d = self.dimensions.head_dim;
244 b * h * (s * s * d * 2 + s * s + s * s * d * 2)
246 }
247 Operation::Elementwise
249 | Operation::Conv2d
250 | Operation::Softmax
251 | Operation::LayerNorm
252 | Operation::Ffn
253 | Operation::Reduce
254 | Operation::Custom(_) => self.dimensions.n,
255 }
256 }
257}