1pub mod ssa {
2 use crate::lang::ssa;
3 use crate::lang::ssa::{BinaryOp, Const, DType, Instr as I, Kernel, VarId, A};
4 use crate::Result;
5
6 fn arg(index: usize, dtype: DType) -> I {
7 I::DefineGlobal { index, dtype }
8 }
9
10 pub fn simple_add(vec_len: usize) -> Result<Kernel> {
11 let v = VarId::new;
12 let a = |i| A::Var(VarId::new(i));
13 let i32 = |i| A::Const(Const::I32(i));
14 let dtype = DType::I32;
15 let instrs = vec![
16 arg(0, DType::I32),
17 arg(1, DType::I32),
18 arg(2, DType::I32),
19 I::Range { lo: i32(0), up: i32(vec_len as i32), end_idx: v(8), step: 1 },
20 I::Load { src: v(1), offset: a(3), dtype },
21 I::Load { src: v(2), offset: a(3), dtype },
22 I::Binary { op: self::BinaryOp::Add, lhs: a(4), rhs: a(5), dtype },
23 I::Store { dst: v(0), offset: a(3), value: a(6), dtype },
24 I::EndRange { start_idx: v(3) },
25 ];
26 Kernel::from_instrs(instrs)
27 }
28
29 pub fn simple_dotprod(vec_len: usize) -> Result<Kernel> {
30 let v = VarId::new;
31 let a = |i| A::Var(VarId::new(i));
32 let dtype = DType::F32;
33 let instrs = vec![
34 arg(0, DType::F32),
35 arg(1, DType::F32),
36 arg(2, DType::F32),
37 I::Const(Const::I32(0)),
38 I::Const(Const::I32(vec_len as i32)),
39 I::DefineAcc(0f32.try_into()?),
40 I::Range { lo: a(3), up: a(4), end_idx: v(12), step: 1 },
41 I::Load { src: v(1), offset: a(6), dtype },
42 I::Load { src: v(2), offset: a(6), dtype },
43 I::Binary { op: self::BinaryOp::Mul, lhs: a(7), rhs: a(8), dtype },
44 I::Binary { op: self::BinaryOp::Add, lhs: a(9), rhs: a(5), dtype },
45 I::Assign { dst: v(5), src: a(10) },
46 I::EndRange { start_idx: v(6) },
47 I::Store { dst: v(0), offset: a(3), value: a(5), dtype },
48 ];
49 Kernel::from_instrs(instrs)
50 }
51
52 pub fn exp(block_size: usize) -> Result<Kernel> {
53 let mut b = crate::block::Block::empty();
54 let dtype = DType::F32;
55 let src_i = b.push(arg(0, DType::F32));
56 let dst_i = b.push(arg(1, DType::F32));
57 let g_i = b.push(I::Special(ssa::Special::BlockIdx));
58 let l_i = b.push(I::Special(ssa::Special::ThreadIdx));
59 let off_i = b.mul(g_i, block_size as i32);
60 let off_i = b.binary(BinaryOp::Add, off_i, l_i, DType::I32);
61 let load_i = b.push(I::Load { src: src_i.to_varid(), offset: off_i.to_a(), dtype });
62 let value_i = b.unary(ssa::UnaryOp::Exp, load_i, dtype);
63 b.push(I::Store {
64 dst: dst_i.to_varid(),
65 offset: off_i.to_a(),
66 value: value_i.to_a(),
67 dtype,
68 });
69 let instrs = b.relocate()?;
70 Kernel::from_instrs(instrs)
71 }
72
73 pub fn exp_block(dim2: usize, block_size: usize) -> Result<Kernel> {
74 if dim2 % block_size != 0 {
75 crate::bail!("last-dim {dim2} must be divisible by block size {block_size}")
76 }
77
78 let mut b = crate::block::Block::empty();
79 let dtype = DType::F32;
80 let src_i = b.push(arg(0, DType::F32));
81 let dst_i = b.push(arg(1, DType::F32));
82 let g_i = b.push(I::Special(ssa::Special::BlockIdx));
83 let l_i = b.push(I::Special(ssa::Special::ThreadIdx));
84 let off_i = b.mul(g_i, dim2 as i32);
85 let off_i = b.binary(BinaryOp::Add, off_i, l_i, DType::I32);
86
87 for i in (0..dim2).step_by(block_size) {
88 let off_i = b.add(off_i, i as i32);
89 let load_i = b.push(I::Load { src: src_i.to_varid(), offset: off_i.to_a(), dtype });
90 let value_i = b.unary(ssa::UnaryOp::Exp, load_i, dtype);
91 b.push(I::Store {
92 dst: dst_i.to_varid(),
93 offset: off_i.to_a(),
94 value: value_i.to_a(),
95 dtype,
96 });
97 }
98 Kernel::from_instrs(b.relocate()?)
99 }
100
101 pub fn softmax_barrier(_dim1: usize, dim2: usize) -> Result<Kernel> {
102 let mut b = crate::block::Block::empty();
103 let dtype = DType::F32;
104 let src_i = b.push(arg(0, DType::F32));
105 let dst_i = b.push(arg(1, DType::F32));
106 let sto_i = b.push(I::DefineLocal { size: (2 * dim2), dtype: DType::F32 }).to_varid();
107 let g_i = b.push(I::Special(ssa::Special::BlockIdx));
108 let l_i = b.push(I::Special(ssa::Special::ThreadIdx));
109 let base_off_i = b.mul(g_i, dim2 as i32);
110 let global_off_i = b.binary(BinaryOp::Add, base_off_i, l_i, DType::I32);
111 let load_i = b.push(I::Load { src: src_i.to_varid(), offset: global_off_i.to_a(), dtype });
112
113 b.push(I::Store { dst: sto_i, offset: l_i.to_a(), value: load_i.to_a(), dtype });
117 b.push(I::Barrier);
118 let mut offset = 1;
119 let mut v_i = load_i;
120 while offset < dim2 {
122 let off_i = b.add(l_i, offset as i32);
123 let m_i = b.push(I::Load { src: sto_i, offset: off_i.to_a(), dtype });
124 v_i = b.binary(BinaryOp::Max, m_i, v_i, dtype);
125 b.push(I::Store { dst: sto_i, offset: l_i.to_a(), value: v_i.to_a(), dtype });
126 b.push(I::Barrier);
127 offset *= 2
128 }
129 let max_i = b.push(I::Load { src: sto_i, offset: 0.into(), dtype });
130
131 let value_i = b.binary(BinaryOp::Sub, load_i, max_i, dtype);
132 let value_i = b.unary(ssa::UnaryOp::Exp, value_i, dtype);
133
134 b.push(I::Store { dst: sto_i, offset: l_i.to_a(), value: value_i.to_a(), dtype });
136 b.push(I::Barrier);
137 let mut offset = 1;
138 let mut v_i = value_i;
139 while offset < dim2 {
140 let off_i = b.add(l_i, offset as i32);
141 let m_i = b.push(I::Load { src: sto_i, offset: off_i.to_a(), dtype });
142 v_i = b.binary(BinaryOp::Add, m_i, v_i, dtype);
143 b.push(I::Store { dst: sto_i, offset: l_i.to_a(), value: v_i.to_a(), dtype });
144 b.push(I::Barrier);
145 offset *= 2
146 }
147 let sum_i = b.push(I::Load { src: sto_i, offset: 0.into(), dtype });
148
149 let value_i = b.binary(BinaryOp::Div, value_i, sum_i, dtype);
151 b.push(I::Store {
152 dst: dst_i.to_varid(),
153 offset: global_off_i.to_a(),
154 value: value_i.to_a(),
155 dtype,
156 });
157 Kernel::from_instrs(b.relocate()?)
158 }
159
160 pub fn softmax_reduce(_dim1: usize, dim2: usize) -> Result<Kernel> {
161 let mut b = crate::block::Block::empty();
162 let dtype = DType::F32;
163 let src_i = b.push(arg(0, dtype));
164 let dst_i = b.push(arg(1, dtype));
165 let g_i = b.push(I::Special(ssa::Special::BlockIdx));
166 let l_i = b.push(I::Special(ssa::Special::ThreadIdx));
167 let base_off_i = b.mul(g_i, dim2 as i32);
168 let global_off_i = b.binary(BinaryOp::Add, base_off_i, l_i, DType::I32);
169 let load_i = b.push(I::Load { src: src_i.to_varid(), offset: global_off_i.to_a(), dtype });
170 let max_i = b.push(I::ReduceLocal { op: ssa::ReduceOp::Max, arg: load_i.to_a(), dtype });
171 let value_i = b.binary(BinaryOp::Sub, load_i, max_i, dtype);
172 let value_i = b.unary(ssa::UnaryOp::Exp, value_i, dtype);
173 let sum_i = b.push(I::ReduceLocal { op: ssa::ReduceOp::Sum, arg: value_i.to_a(), dtype });
174 let value_i = b.binary(BinaryOp::Div, value_i, sum_i, dtype);
176 b.push(I::Store {
177 dst: dst_i.to_varid(),
178 offset: global_off_i.to_a(),
179 value: value_i.to_a(),
180 dtype,
181 });
182 Kernel::from_instrs(b.relocate()?)
183 }
184
185 pub fn softmax_block(_dim1: usize, dim2: usize, block_size: usize) -> Result<Kernel> {
186 if dim2 % block_size != 0 {
187 crate::bail!("last-dim {dim2} must be divisible by block size {block_size}")
188 }
189 let per_block = dim2 / block_size;
190 let mut b = crate::block::Block::empty();
191 let dtype = DType::F32;
192 let src_i = b.push(arg(0, dtype));
193 let dst_i = b.push(arg(1, dtype));
194 let g_i = b.push(I::Special(ssa::Special::BlockIdx));
195 let l_i = b.push(I::Special(ssa::Special::ThreadIdx));
196 let base_off_i = b.mul(g_i, dim2 as i32);
197 let global_off_i = b.binary(BinaryOp::Add, base_off_i, l_i, DType::I32);
198
199 let mut load_is = Vec::with_capacity(per_block);
200
201 let mut max_i = b.push(I::Const(f32::NEG_INFINITY.try_into()?));
202 for i in (0..dim2).step_by(block_size) {
203 let offset = b.add(global_off_i, i as i32).to_a();
204 let load_i = b.push(I::Load { src: src_i.to_varid(), offset, dtype });
205 max_i = b.binary(BinaryOp::Max, max_i, load_i, dtype);
206 load_is.push(load_i)
207 }
208 let max_i = b.push(I::ReduceLocal { op: ssa::ReduceOp::Max, arg: max_i.to_a(), dtype });
209
210 let mut value_is = Vec::with_capacity(per_block);
211 let mut sum_i = b.push(I::Const(Const::I32(0)));
212 for load_i in load_is.into_iter() {
213 let value_i = b.binary(BinaryOp::Sub, load_i, max_i, dtype);
214 let value_i = b.unary(ssa::UnaryOp::Exp, value_i, dtype);
215 sum_i = b.binary(BinaryOp::Add, value_i, sum_i, dtype);
216 value_is.push(value_i);
217 }
218 let sum_i = b.push(I::ReduceLocal { op: ssa::ReduceOp::Sum, arg: sum_i.to_a(), dtype });
219
220 for (i, value_i) in value_is.into_iter().enumerate() {
221 let i = i * block_size;
222 let offset = b.add(global_off_i, i as i32).to_a();
223 let value_i = b.binary(BinaryOp::Div, value_i, sum_i, dtype);
225 b.push(I::Store { dst: dst_i.to_varid(), offset, value: value_i.to_a(), dtype });
226 }
227 Kernel::from_instrs(b.relocate()?)
228 }
229
230 pub fn softmax(dim1: usize, dim2: usize) -> Result<Kernel> {
231 softmax_barrier(dim1, dim2)
232 }
233}
234
235pub mod op {
236 use crate::lang::op::{self, Arg, DType, Kernel, Layout};
237 use crate::Result;
238
239 pub fn softmax(dim1: usize, dim2: usize) -> Result<Kernel> {
240 let layout = Layout::from_shape(&[dim1, dim2]);
241 let src_ptr = Arg::ptr(DType::F32);
242 let dst_ptr = Arg::ptr(DType::F32);
243 let src = op::load(src_ptr.id(), layout.clone(), DType::F32)?;
244 let src_max = op::reduce(op::ReduceOp::Max, src.clone(), 1)?;
245 let src_max = op::broadcast(src_max, (dim1, dim2))?;
246 let diff = op::binary(op::BinaryOp::Sub, src, src_max)?;
247 let exp = op::unary(op::UnaryOp::Exp, diff)?;
248 let sum_exp = op::reduce(op::ReduceOp::Sum, exp.clone(), 1)?;
249 let sum_exp = op::broadcast(sum_exp, (dim1, dim2))?;
250 let sm = op::binary(op::BinaryOp::Div, exp, sum_exp)?;
251 let st = op::store(dst_ptr.id(), layout, sm)?;
252 let kernel =
253 Kernel::new(format!("softmax_{dim1}_{dim2}"), vec![src_ptr, dst_ptr], vec![st]);
254 Ok(kernel)
255 }
256}
257
258use crate::lang::{Arg, DType, ExprNode as E, IndexExprNode as I, Kernel, Ops};
259
260pub fn simple_add(block_size: usize) -> crate::Result<Kernel> {
261 let lhs_ptr = Arg::ptr(DType::F32);
262 let rhs_ptr = Arg::ptr(DType::F32);
263 let dst_ptr = Arg::ptr(DType::F32);
264 let offset = I::mul(&I::program_id(), &I::cst(block_size));
265 let stride = I::cst(1);
266 let len = I::cst(block_size);
267 let lhs = E::load(&lhs_ptr, &offset, &len, &stride)?;
268 let rhs = E::load(&rhs_ptr, &offset, &len, &stride)?;
269 let op = Ops::store(&dst_ptr, &offset, &len, &stride, &lhs.add(&rhs));
270 let k =
271 Kernel::new(format!("simple_add_{block_size}"), vec![lhs_ptr, rhs_ptr, dst_ptr], vec![op]);
272 Ok(k)
273}