// cbtop/grammar/workload.rs
//! Workload specification (Data equivalent in Grammar of Graphics).

/// Compute operation type.
///
/// Names the kind of kernel a `WorkloadSpec` describes;
/// `WorkloadSpec::flop_count` dispatches on this to choose a FLOP model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
    /// Element-wise operations (independent per-element map)
    Elementwise,
    /// Dot product of two vectors
    Dot,
    /// Matrix multiplication
    Matmul,
    /// 2D convolution
    Conv2d,
    /// Multi-head attention
    Attention,
    /// Softmax
    Softmax,
    /// Layer normalization
    LayerNorm,
    /// Feed-forward network
    Ffn,
    /// Reduction (sum, mean, max, etc.)
    Reduce,
    /// Custom operation, distinguished by an opaque user-chosen id
    Custom(u32),
}

/// Data type for compute operations.
///
/// Per-element storage size is given by `DataType::byte_size`; note that
/// `Q4` is sub-byte (two values packed per byte).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit floating point
    F32,
    /// 16-bit floating point
    F16,
    /// Brain float 16
    Bf16,
    /// 8-bit signed integer
    I8,
    /// Unsigned 8-bit integer
    U8,
    /// 4-bit quantized (packed, two values per byte)
    Q4,
}

45impl DataType {
46    /// Get byte size of data type
47    pub fn byte_size(&self) -> usize {
48        match self {
49            DataType::F32 => 4,
50            DataType::F16 | DataType::Bf16 => 2,
51            DataType::I8 | DataType::U8 => 1,
52            DataType::Q4 => 1, // 2 values per byte
53        }
54    }
55}
56
/// Problem dimensions.
///
/// A single flat bag of sizes shared by all operation kinds; each
/// constructor (`vector`, `matmul`, `attention`) fills in only the fields
/// relevant to that operation and leaves the rest at their defaults.
#[derive(Debug, Clone, PartialEq)]
pub struct Dimensions {
    /// Batch size
    pub batch: usize,
    /// Sequence length (for attention)
    pub seq_len: usize,
    /// Number of heads (for attention)
    pub num_heads: usize,
    /// Head dimension (for attention)
    pub head_dim: usize,
    /// Hidden dimension (for FFN)
    pub hidden_dim: usize,
    /// M dimension (rows of A and C, for matmul)
    pub m: usize,
    /// N dimension (cols of B and C, for matmul; also the vector length
    /// set by `Dimensions::vector`)
    pub n: usize,
    /// K dimension (shared inner dimension, for matmul)
    pub k: usize,
}

impl Default for Dimensions {
    /// Neutral dimensions: every size is 1 except `head_dim`, which
    /// defaults to 64.
    /// NOTE(review): 64 is a common attention head width — confirm this
    /// asymmetric default is intentional rather than 1 like the rest.
    fn default() -> Self {
        Self {
            batch: 1,
            seq_len: 1,
            num_heads: 1,
            head_dim: 64,
            hidden_dim: 1,
            m: 1,
            n: 1,
            k: 1,
        }
    }
}

93impl Dimensions {
94    /// Create dimensions for vector operation
95    pub fn vector(size: usize) -> Self {
96        Self {
97            n: size,
98            ..Default::default()
99        }
100    }
101
102    /// Create dimensions for matrix multiplication
103    pub fn matmul(m: usize, n: usize, k: usize) -> Self {
104        Self {
105            m,
106            n,
107            k,
108            ..Default::default()
109        }
110    }
111
112    /// Create dimensions for attention
113    pub fn attention(batch: usize, seq_len: usize, num_heads: usize, head_dim: usize) -> Self {
114        Self {
115            batch,
116            seq_len,
117            num_heads,
118            head_dim,
119            ..Default::default()
120        }
121    }
122}
123
/// Tensor specification.
///
/// Describes one input or output buffer of a workload: a name, a shape,
/// an element type, and an optional explicit stride.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorSpec {
    /// Tensor name/identifier
    pub name: String,
    /// Shape dimensions (length = rank)
    pub shape: Vec<usize>,
    /// Element data type
    pub dtype: DataType,
    /// Stride (optional, for non-contiguous tensors); `None` means
    /// contiguous. NOTE(review): units (elements vs bytes) are not
    /// established here — confirm against consumers.
    pub stride: Option<Vec<usize>>,
}

137impl TensorSpec {
138    /// Create a new tensor spec
139    pub fn new(name: impl Into<String>, shape: Vec<usize>, dtype: DataType) -> Self {
140        Self {
141            name: name.into(),
142            shape,
143            dtype,
144            stride: None,
145        }
146    }
147
148    /// Total number of elements
149    pub fn numel(&self) -> usize {
150        self.shape.iter().product()
151    }
152
153    /// Total byte size
154    pub fn byte_size(&self) -> usize {
155        self.numel() * self.dtype.byte_size()
156    }
157}
158
/// Workload specification (analogous to a DataFrame in the Grammar of
/// Graphics: the "data" that the rest of the grammar maps onto output).
#[derive(Debug, Clone, PartialEq)]
pub struct WorkloadSpec {
    /// Operation type
    pub operation: Operation,
    /// Problem dimensions
    pub dimensions: Dimensions,
    /// Primary data type (individual tensors carry their own dtype too)
    pub dtype: DataType,
    /// Input tensor specifications
    pub inputs: Vec<TensorSpec>,
    /// Output tensor specifications
    pub outputs: Vec<TensorSpec>,
}

174impl WorkloadSpec {
175    /// Create a dot product workload
176    pub fn dot(size: usize) -> Self {
177        Self {
178            operation: Operation::Dot,
179            dimensions: Dimensions::vector(size),
180            dtype: DataType::F32,
181            inputs: vec![
182                TensorSpec::new("a", vec![size], DataType::F32),
183                TensorSpec::new("b", vec![size], DataType::F32),
184            ],
185            outputs: vec![TensorSpec::new("result", vec![1], DataType::F32)],
186        }
187    }
188
189    /// Create a matrix multiplication workload
190    pub fn matmul(m: usize, n: usize, k: usize) -> Self {
191        Self {
192            operation: Operation::Matmul,
193            dimensions: Dimensions::matmul(m, n, k),
194            dtype: DataType::F32,
195            inputs: vec![
196                TensorSpec::new("a", vec![m, k], DataType::F32),
197                TensorSpec::new("b", vec![k, n], DataType::F32),
198            ],
199            outputs: vec![TensorSpec::new("c", vec![m, n], DataType::F32)],
200        }
201    }
202
203    /// Create an attention workload
204    pub fn attention(batch: usize, seq_len: usize, num_heads: usize, head_dim: usize) -> Self {
205        let embed_dim = num_heads * head_dim;
206        Self {
207            operation: Operation::Attention,
208            dimensions: Dimensions::attention(batch, seq_len, num_heads, head_dim),
209            dtype: DataType::F32,
210            inputs: vec![
211                TensorSpec::new("q", vec![batch, seq_len, embed_dim], DataType::F32),
212                TensorSpec::new("k", vec![batch, seq_len, embed_dim], DataType::F32),
213                TensorSpec::new("v", vec![batch, seq_len, embed_dim], DataType::F32),
214            ],
215            outputs: vec![TensorSpec::new(
216                "out",
217                vec![batch, seq_len, embed_dim],
218                DataType::F32,
219            )],
220        }
221    }
222
223    /// Create an elementwise workload
224    pub fn elementwise(size: usize) -> Self {
225        Self {
226            operation: Operation::Elementwise,
227            dimensions: Dimensions::vector(size),
228            dtype: DataType::F32,
229            inputs: vec![TensorSpec::new("input", vec![size], DataType::F32)],
230            outputs: vec![TensorSpec::new("output", vec![size], DataType::F32)],
231        }
232    }
233
234    /// Total FLOP count estimate
235    pub fn flop_count(&self) -> usize {
236        match self.operation {
237            Operation::Dot => self.dimensions.n * 2, // mul + add
238            Operation::Matmul => self.dimensions.m * self.dimensions.n * self.dimensions.k * 2,
239            Operation::Attention => {
240                let b = self.dimensions.batch;
241                let s = self.dimensions.seq_len;
242                let h = self.dimensions.num_heads;
243                let d = self.dimensions.head_dim;
244                // QK^T + softmax + AV
245                b * h * (s * s * d * 2 + s * s + s * s * d * 2)
246            }
247            // Default estimate for remaining operations
248            Operation::Elementwise
249            | Operation::Conv2d
250            | Operation::Softmax
251            | Operation::LayerNorm
252            | Operation::Ffn
253            | Operation::Reduce
254            | Operation::Custom(_) => self.dimensions.n,
255        }
256    }
257}