// ug/samples.rs

1pub mod ssa {
2    use crate::lang::ssa;
3    use crate::lang::ssa::{BinaryOp, Const, DType, Instr as I, Kernel, VarId, A};
4    use crate::Result;
5
6    fn arg(index: usize, dtype: DType) -> I {
7        I::DefineGlobal { index, dtype }
8    }
9
10    pub fn simple_add(vec_len: usize) -> Result<Kernel> {
11        let v = VarId::new;
12        let a = |i| A::Var(VarId::new(i));
13        let i32 = |i| A::Const(Const::I32(i));
14        let dtype = DType::I32;
15        let instrs = vec![
16            /* 0 */ arg(0, DType::I32),
17            /* 1 */ arg(1, DType::I32),
18            /* 2 */ arg(2, DType::I32),
19            /* 3 */ I::Range { lo: i32(0), up: i32(vec_len as i32), end_idx: v(8), step: 1 },
20            /* 4 */ I::Load { src: v(1), offset: a(3), dtype },
21            /* 5 */ I::Load { src: v(2), offset: a(3), dtype },
22            /* 6 */ I::Binary { op: self::BinaryOp::Add, lhs: a(4), rhs: a(5), dtype },
23            /* 7 */ I::Store { dst: v(0), offset: a(3), value: a(6), dtype },
24            /* 8 */ I::EndRange { start_idx: v(3) },
25        ];
26        Kernel::from_instrs(instrs)
27    }
28
29    pub fn simple_dotprod(vec_len: usize) -> Result<Kernel> {
30        let v = VarId::new;
31        let a = |i| A::Var(VarId::new(i));
32        let dtype = DType::F32;
33        let instrs = vec![
34            /* 0 */ arg(0, DType::F32),
35            /* 1 */ arg(1, DType::F32),
36            /* 2 */ arg(2, DType::F32),
37            /* 3 */ I::Const(Const::I32(0)),
38            /* 4 */ I::Const(Const::I32(vec_len as i32)),
39            /* 5 */ I::DefineAcc(0f32.try_into()?),
40            /* 6 */ I::Range { lo: a(3), up: a(4), end_idx: v(12), step: 1 },
41            /* 7 */ I::Load { src: v(1), offset: a(6), dtype },
42            /* 8 */ I::Load { src: v(2), offset: a(6), dtype },
43            /* 9 */ I::Binary { op: self::BinaryOp::Mul, lhs: a(7), rhs: a(8), dtype },
44            /* 10*/ I::Binary { op: self::BinaryOp::Add, lhs: a(9), rhs: a(5), dtype },
45            /* 11*/ I::Assign { dst: v(5), src: a(10) },
46            /* 12*/ I::EndRange { start_idx: v(6) },
47            /* 13*/ I::Store { dst: v(0), offset: a(3), value: a(5), dtype },
48        ];
49        Kernel::from_instrs(instrs)
50    }
51
52    pub fn exp(block_size: usize) -> Result<Kernel> {
53        let mut b = crate::block::Block::empty();
54        let dtype = DType::F32;
55        let src_i = b.push(arg(0, DType::F32));
56        let dst_i = b.push(arg(1, DType::F32));
57        let g_i = b.push(I::Special(ssa::Special::BlockIdx));
58        let l_i = b.push(I::Special(ssa::Special::ThreadIdx));
59        let off_i = b.mul(g_i, block_size as i32);
60        let off_i = b.binary(BinaryOp::Add, off_i, l_i, DType::I32);
61        let load_i = b.push(I::Load { src: src_i.to_varid(), offset: off_i.to_a(), dtype });
62        let value_i = b.unary(ssa::UnaryOp::Exp, load_i, dtype);
63        b.push(I::Store {
64            dst: dst_i.to_varid(),
65            offset: off_i.to_a(),
66            value: value_i.to_a(),
67            dtype,
68        });
69        let instrs = b.relocate()?;
70        Kernel::from_instrs(instrs)
71    }
72
73    pub fn exp_block(dim2: usize, block_size: usize) -> Result<Kernel> {
74        if dim2 % block_size != 0 {
75            crate::bail!("last-dim {dim2} must be divisible by block size {block_size}")
76        }
77
78        let mut b = crate::block::Block::empty();
79        let dtype = DType::F32;
80        let src_i = b.push(arg(0, DType::F32));
81        let dst_i = b.push(arg(1, DType::F32));
82        let g_i = b.push(I::Special(ssa::Special::BlockIdx));
83        let l_i = b.push(I::Special(ssa::Special::ThreadIdx));
84        let off_i = b.mul(g_i, dim2 as i32);
85        let off_i = b.binary(BinaryOp::Add, off_i, l_i, DType::I32);
86
87        for i in (0..dim2).step_by(block_size) {
88            let off_i = b.add(off_i, i as i32);
89            let load_i = b.push(I::Load { src: src_i.to_varid(), offset: off_i.to_a(), dtype });
90            let value_i = b.unary(ssa::UnaryOp::Exp, load_i, dtype);
91            b.push(I::Store {
92                dst: dst_i.to_varid(),
93                offset: off_i.to_a(),
94                value: value_i.to_a(),
95                dtype,
96            });
97        }
98        Kernel::from_instrs(b.relocate()?)
99    }
100
101    pub fn softmax_barrier(_dim1: usize, dim2: usize) -> Result<Kernel> {
102        let mut b = crate::block::Block::empty();
103        let dtype = DType::F32;
104        let src_i = b.push(arg(0, DType::F32));
105        let dst_i = b.push(arg(1, DType::F32));
106        let sto_i = b.push(I::DefineLocal { size: (2 * dim2), dtype: DType::F32 }).to_varid();
107        let g_i = b.push(I::Special(ssa::Special::BlockIdx));
108        let l_i = b.push(I::Special(ssa::Special::ThreadIdx));
109        let base_off_i = b.mul(g_i, dim2 as i32);
110        let global_off_i = b.binary(BinaryOp::Add, base_off_i, l_i, DType::I32);
111        let load_i = b.push(I::Load { src: src_i.to_varid(), offset: global_off_i.to_a(), dtype });
112
113        // Compute the max value over dim2.
114        // This implementation uses shared memory which is likely slower than using warp-reduce
115        // primitives but these are harder to scale to more than 32 threads.
116        b.push(I::Store { dst: sto_i, offset: l_i.to_a(), value: load_i.to_a(), dtype });
117        b.push(I::Barrier);
118        let mut offset = 1;
119        let mut v_i = load_i;
120        // This is only ok if dim2 is a power of 2.
121        while offset < dim2 {
122            let off_i = b.add(l_i, offset as i32);
123            let m_i = b.push(I::Load { src: sto_i, offset: off_i.to_a(), dtype });
124            v_i = b.binary(BinaryOp::Max, m_i, v_i, dtype);
125            b.push(I::Store { dst: sto_i, offset: l_i.to_a(), value: v_i.to_a(), dtype });
126            b.push(I::Barrier);
127            offset *= 2
128        }
129        let max_i = b.push(I::Load { src: sto_i, offset: 0.into(), dtype });
130
131        let value_i = b.binary(BinaryOp::Sub, load_i, max_i, dtype);
132        let value_i = b.unary(ssa::UnaryOp::Exp, value_i, dtype);
133
134        // Compute the sum of the exps
135        b.push(I::Store { dst: sto_i, offset: l_i.to_a(), value: value_i.to_a(), dtype });
136        b.push(I::Barrier);
137        let mut offset = 1;
138        let mut v_i = value_i;
139        while offset < dim2 {
140            let off_i = b.add(l_i, offset as i32);
141            let m_i = b.push(I::Load { src: sto_i, offset: off_i.to_a(), dtype });
142            v_i = b.binary(BinaryOp::Add, m_i, v_i, dtype);
143            b.push(I::Store { dst: sto_i, offset: l_i.to_a(), value: v_i.to_a(), dtype });
144            b.push(I::Barrier);
145            offset *= 2
146        }
147        let sum_i = b.push(I::Load { src: sto_i, offset: 0.into(), dtype });
148
149        // Normalize by sum_exp
150        let value_i = b.binary(BinaryOp::Div, value_i, sum_i, dtype);
151        b.push(I::Store {
152            dst: dst_i.to_varid(),
153            offset: global_off_i.to_a(),
154            value: value_i.to_a(),
155            dtype,
156        });
157        Kernel::from_instrs(b.relocate()?)
158    }
159
160    pub fn softmax_reduce(_dim1: usize, dim2: usize) -> Result<Kernel> {
161        let mut b = crate::block::Block::empty();
162        let dtype = DType::F32;
163        let src_i = b.push(arg(0, dtype));
164        let dst_i = b.push(arg(1, dtype));
165        let g_i = b.push(I::Special(ssa::Special::BlockIdx));
166        let l_i = b.push(I::Special(ssa::Special::ThreadIdx));
167        let base_off_i = b.mul(g_i, dim2 as i32);
168        let global_off_i = b.binary(BinaryOp::Add, base_off_i, l_i, DType::I32);
169        let load_i = b.push(I::Load { src: src_i.to_varid(), offset: global_off_i.to_a(), dtype });
170        let max_i = b.push(I::ReduceLocal { op: ssa::ReduceOp::Max, arg: load_i.to_a(), dtype });
171        let value_i = b.binary(BinaryOp::Sub, load_i, max_i, dtype);
172        let value_i = b.unary(ssa::UnaryOp::Exp, value_i, dtype);
173        let sum_i = b.push(I::ReduceLocal { op: ssa::ReduceOp::Sum, arg: value_i.to_a(), dtype });
174        // Normalize by sum_exp
175        let value_i = b.binary(BinaryOp::Div, value_i, sum_i, dtype);
176        b.push(I::Store {
177            dst: dst_i.to_varid(),
178            offset: global_off_i.to_a(),
179            value: value_i.to_a(),
180            dtype,
181        });
182        Kernel::from_instrs(b.relocate()?)
183    }
184
185    pub fn softmax_block(_dim1: usize, dim2: usize, block_size: usize) -> Result<Kernel> {
186        if dim2 % block_size != 0 {
187            crate::bail!("last-dim {dim2} must be divisible by block size {block_size}")
188        }
189        let per_block = dim2 / block_size;
190        let mut b = crate::block::Block::empty();
191        let dtype = DType::F32;
192        let src_i = b.push(arg(0, dtype));
193        let dst_i = b.push(arg(1, dtype));
194        let g_i = b.push(I::Special(ssa::Special::BlockIdx));
195        let l_i = b.push(I::Special(ssa::Special::ThreadIdx));
196        let base_off_i = b.mul(g_i, dim2 as i32);
197        let global_off_i = b.binary(BinaryOp::Add, base_off_i, l_i, DType::I32);
198
199        let mut load_is = Vec::with_capacity(per_block);
200
201        let mut max_i = b.push(I::Const(f32::NEG_INFINITY.try_into()?));
202        for i in (0..dim2).step_by(block_size) {
203            let offset = b.add(global_off_i, i as i32).to_a();
204            let load_i = b.push(I::Load { src: src_i.to_varid(), offset, dtype });
205            max_i = b.binary(BinaryOp::Max, max_i, load_i, dtype);
206            load_is.push(load_i)
207        }
208        let max_i = b.push(I::ReduceLocal { op: ssa::ReduceOp::Max, arg: max_i.to_a(), dtype });
209
210        let mut value_is = Vec::with_capacity(per_block);
211        let mut sum_i = b.push(I::Const(Const::I32(0)));
212        for load_i in load_is.into_iter() {
213            let value_i = b.binary(BinaryOp::Sub, load_i, max_i, dtype);
214            let value_i = b.unary(ssa::UnaryOp::Exp, value_i, dtype);
215            sum_i = b.binary(BinaryOp::Add, value_i, sum_i, dtype);
216            value_is.push(value_i);
217        }
218        let sum_i = b.push(I::ReduceLocal { op: ssa::ReduceOp::Sum, arg: sum_i.to_a(), dtype });
219
220        for (i, value_i) in value_is.into_iter().enumerate() {
221            let i = i * block_size;
222            let offset = b.add(global_off_i, i as i32).to_a();
223            // Normalize by sum_exp
224            let value_i = b.binary(BinaryOp::Div, value_i, sum_i, dtype);
225            b.push(I::Store { dst: dst_i.to_varid(), offset, value: value_i.to_a(), dtype });
226        }
227        Kernel::from_instrs(b.relocate()?)
228    }
229
230    pub fn softmax(dim1: usize, dim2: usize) -> Result<Kernel> {
231        softmax_barrier(dim1, dim2)
232    }
233}
234
235pub mod op {
236    use crate::lang::op::{self, Arg, DType, Kernel, Layout};
237    use crate::Result;
238
239    pub fn softmax(dim1: usize, dim2: usize) -> Result<Kernel> {
240        let layout = Layout::from_shape(&[dim1, dim2]);
241        let src_ptr = Arg::ptr(DType::F32);
242        let dst_ptr = Arg::ptr(DType::F32);
243        let src = op::load(src_ptr.id(), layout.clone(), DType::F32)?;
244        let src_max = op::reduce(op::ReduceOp::Max, src.clone(), 1)?;
245        let src_max = op::broadcast(src_max, (dim1, dim2))?;
246        let diff = op::binary(op::BinaryOp::Sub, src, src_max)?;
247        let exp = op::unary(op::UnaryOp::Exp, diff)?;
248        let sum_exp = op::reduce(op::ReduceOp::Sum, exp.clone(), 1)?;
249        let sum_exp = op::broadcast(sum_exp, (dim1, dim2))?;
250        let sm = op::binary(op::BinaryOp::Div, exp, sum_exp)?;
251        let st = op::store(dst_ptr.id(), layout, sm)?;
252        let kernel =
253            Kernel::new(format!("softmax_{dim1}_{dim2}"), vec![src_ptr, dst_ptr], vec![st]);
254        Ok(kernel)
255    }
256}
257
258use crate::lang::{Arg, DType, ExprNode as E, IndexExprNode as I, Kernel, Ops};
259
260pub fn simple_add(block_size: usize) -> crate::Result<Kernel> {
261    let lhs_ptr = Arg::ptr(DType::F32);
262    let rhs_ptr = Arg::ptr(DType::F32);
263    let dst_ptr = Arg::ptr(DType::F32);
264    let offset = I::mul(&I::program_id(), &I::cst(block_size));
265    let stride = I::cst(1);
266    let len = I::cst(block_size);
267    let lhs = E::load(&lhs_ptr, &offset, &len, &stride)?;
268    let rhs = E::load(&rhs_ptr, &offset, &len, &stride)?;
269    let op = Ops::store(&dst_ptr, &offset, &len, &stride, &lhs.add(&rhs));
270    let k =
271        Kernel::new(format!("simple_add_{block_size}"), vec![lhs_ptr, rhs_ptr, dst_ptr], vec![op]);
272    Ok(k)
273}