1use crate::op::*;
33use crate::shape;
34use crate::{DType, Graph, NodeId, Op, Shape};
35
36pub trait GraphExt {
38 fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
40
41 fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
43 fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
44 fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
45 fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
46
47 fn gelu(&mut self, x: NodeId) -> NodeId;
49 fn gelu_approx(&mut self, x: NodeId) -> NodeId;
54 fn silu(&mut self, x: NodeId) -> NodeId;
55 fn relu(&mut self, x: NodeId) -> NodeId;
56 fn exp(&mut self, x: NodeId) -> NodeId;
57 fn sqrt(&mut self, x: NodeId) -> NodeId;
58 fn neg(&mut self, x: NodeId) -> NodeId;
59 fn tanh(&mut self, x: NodeId) -> NodeId;
60
61 fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
63 fn layer_norm2d(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
64 fn group_norm(
65 &mut self,
66 x: NodeId,
67 gamma: NodeId,
68 beta: NodeId,
69 num_groups: usize,
70 eps: f32,
71 ) -> NodeId;
72 fn rms_norm(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
73
74 fn conv2d(
76 &mut self,
77 input: NodeId,
78 weight: NodeId,
79 kernel_size: [usize; 2],
80 stride: [usize; 2],
81 padding: [usize; 2],
82 dilation: [usize; 2],
83 groups: usize,
84 ) -> NodeId;
85 fn conv_transpose2d(
86 &mut self,
87 input: NodeId,
88 weight: NodeId,
89 kernel_size: [usize; 2],
90 stride: [usize; 2],
91 padding: [usize; 2],
92 dilation: [usize; 2],
93 output_padding: [usize; 2],
94 groups: usize,
95 ) -> NodeId;
96
97 fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
99 fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
100 fn sm(&mut self, x: NodeId, axis: i32) -> NodeId;
101
102 fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId;
104 fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId;
105 fn narrow_(&mut self, x: NodeId, axis: usize, start: usize, len: usize) -> NodeId;
106 fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId;
107 fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId;
108
109 fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
111 fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
112
113 fn attention_(
115 &mut self,
116 q: NodeId,
117 k: NodeId,
118 v: NodeId,
119 mask: NodeId,
120 num_heads: usize,
121 head_dim: usize,
122 ) -> NodeId;
123
124 fn rope(&mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize) -> NodeId;
126 fn rope_n(
128 &mut self,
129 x: NodeId,
130 cos: NodeId,
131 sin: NodeId,
132 head_dim: usize,
133 n_rot: usize,
134 ) -> NodeId;
135
136 fn cast(&mut self, x: NodeId, to: DType) -> NodeId;
138}
139
140impl GraphExt for Graph {
141 fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
142 let s =
143 shape::matmul_shape(self.shape(lhs), self.shape(rhs)).expect("matmul shape inference");
144 self.matmul(lhs, rhs, s)
145 }
146
147 fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
148 let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("add shape inference");
149 self.binary(BinaryOp::Add, lhs, rhs, s)
150 }
151
152 fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
153 let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("sub shape inference");
154 self.binary(BinaryOp::Sub, lhs, rhs, s)
155 }
156
157 fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
158 let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("mul shape inference");
159 self.binary(BinaryOp::Mul, lhs, rhs, s)
160 }
161
162 fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
163 let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("div shape inference");
164 self.binary(BinaryOp::Div, lhs, rhs, s)
165 }
166
167 fn gelu(&mut self, x: NodeId) -> NodeId {
168 let s = shape::unary_shape(self.shape(x));
169 self.activation(Activation::Gelu, x, s)
170 }
171
172 fn gelu_approx(&mut self, x: NodeId) -> NodeId {
173 let s = shape::unary_shape(self.shape(x));
174 self.activation(Activation::GeluApprox, x, s)
175 }
176
177 fn silu(&mut self, x: NodeId) -> NodeId {
178 let s = shape::unary_shape(self.shape(x));
179 self.activation(Activation::Silu, x, s)
180 }
181
182 fn relu(&mut self, x: NodeId) -> NodeId {
183 let s = shape::unary_shape(self.shape(x));
184 self.activation(Activation::Relu, x, s)
185 }
186
187 fn exp(&mut self, x: NodeId) -> NodeId {
188 let s = shape::unary_shape(self.shape(x));
189 self.activation(Activation::Exp, x, s)
190 }
191
192 fn sqrt(&mut self, x: NodeId) -> NodeId {
193 let s = shape::unary_shape(self.shape(x));
194 self.activation(Activation::Sqrt, x, s)
195 }
196
197 fn neg(&mut self, x: NodeId) -> NodeId {
198 let s = shape::unary_shape(self.shape(x));
199 self.activation(Activation::Neg, x, s)
200 }
201
202 fn tanh(&mut self, x: NodeId) -> NodeId {
203 let s = shape::unary_shape(self.shape(x));
204 self.activation(Activation::Tanh, x, s)
205 }
206
207 fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
208 let s = shape::unary_shape(self.shape(x));
209 self.layer_norm(x, gamma, beta, -1, eps, s)
210 }
211
212 fn layer_norm2d(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
213 Graph::layer_norm2d(self, x, gamma, beta, eps)
214 }
215
216 fn group_norm(
217 &mut self,
218 x: NodeId,
219 gamma: NodeId,
220 beta: NodeId,
221 num_groups: usize,
222 eps: f32,
223 ) -> NodeId {
224 Graph::group_norm(self, x, gamma, beta, num_groups, eps)
225 }
226
227 fn conv2d(
228 &mut self,
229 input: NodeId,
230 weight: NodeId,
231 kernel_size: [usize; 2],
232 stride: [usize; 2],
233 padding: [usize; 2],
234 dilation: [usize; 2],
235 groups: usize,
236 ) -> NodeId {
237 Graph::conv2d(
238 self,
239 input,
240 weight,
241 kernel_size,
242 stride,
243 padding,
244 dilation,
245 groups,
246 )
247 }
248
249 fn conv_transpose2d(
250 &mut self,
251 input: NodeId,
252 weight: NodeId,
253 kernel_size: [usize; 2],
254 stride: [usize; 2],
255 padding: [usize; 2],
256 dilation: [usize; 2],
257 output_padding: [usize; 2],
258 groups: usize,
259 ) -> NodeId {
260 Graph::conv_transpose2d(
261 self,
262 input,
263 weight,
264 kernel_size,
265 stride,
266 padding,
267 dilation,
268 output_padding,
269 groups,
270 )
271 }
272
273 fn rms_norm(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
274 let s = shape::unary_shape(self.shape(x));
275 self.add_node(Op::RmsNorm { axis: -1, eps }, vec![x, gamma, beta], s)
276 }
277
278 fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId {
279 let s =
280 shape::reduce_shape(self.shape(x), &axes, keep_dim).expect("reduce shape inference");
281 self.reduce(x, ReduceOp::Sum, axes, keep_dim, s)
282 }
283
284 fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId {
285 let s =
286 shape::reduce_shape(self.shape(x), &axes, keep_dim).expect("reduce shape inference");
287 self.reduce(x, ReduceOp::Mean, axes, keep_dim, s)
288 }
289
290 fn sm(&mut self, x: NodeId, axis: i32) -> NodeId {
291 let s = shape::softmax_shape(self.shape(x));
292 self.softmax(x, axis, s)
293 }
294
295 fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId {
296 let s = shape::reshape_shape(self.shape(x), &new_shape).expect("reshape shape inference");
297 self.reshape(x, new_shape, s)
298 }
299
300 fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId {
301 let s = shape::transpose_shape(self.shape(x), &perm).expect("transpose shape inference");
302 self.add_node(Op::Transpose { perm }, vec![x], s)
303 }
304
305 fn narrow_(&mut self, x: NodeId, axis: usize, start: usize, len: usize) -> NodeId {
306 let s = shape::narrow_shape(self.shape(x), axis, len).expect("narrow shape inference");
307 self.add_node(Op::Narrow { axis, start, len }, vec![x], s)
308 }
309
310 fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId {
311 let shapes: Vec<&Shape> = inputs.iter().map(|&id| self.shape(id)).collect();
312 let s = shape::concat_shape(&shapes, axis).expect("concat shape inference");
313 self.concat(inputs, axis, s)
314 }
315
316 fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId {
317 let s = shape::gather_shape(self.shape(table), self.shape(indices), axis)
318 .expect("gather shape inference");
319 self.gather(table, indices, axis, s)
320 }
321
322 fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
323 let s = shape::compare_shape(self.shape(lhs), self.shape(rhs))
324 .expect("compare shape inference");
325 self.add_node(Op::Compare(CmpOp::Eq), vec![lhs, rhs], s)
326 }
327
328 fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
329 let s = shape::compare_shape(self.shape(lhs), self.shape(rhs))
330 .expect("compare shape inference");
331 self.add_node(Op::Compare(CmpOp::Lt), vec![lhs, rhs], s)
332 }
333
334 fn attention_(
335 &mut self,
336 q: NodeId,
337 k: NodeId,
338 v: NodeId,
339 mask: NodeId,
340 num_heads: usize,
341 head_dim: usize,
342 ) -> NodeId {
343 let s = shape::attention_shape(self.shape(q));
344 self.attention(q, k, v, mask, num_heads, head_dim, s)
345 }
346
347 fn rope(&mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize) -> NodeId {
348 self.rope_n(x, cos, sin, head_dim, head_dim)
349 }
350
351 fn rope_n(
352 &mut self,
353 x: NodeId,
354 cos: NodeId,
355 sin: NodeId,
356 head_dim: usize,
357 n_rot: usize,
358 ) -> NodeId {
359 assert!(
360 n_rot <= head_dim && n_rot.is_multiple_of(2),
361 "rope_n: n_rot={n_rot} must be even and <= head_dim={head_dim}"
362 );
363 let s = shape::unary_shape(self.shape(x));
364 self.add_node(Op::Rope { head_dim, n_rot }, vec![x, cos, sin], s)
365 }
366
367 fn cast(&mut self, x: NodeId, to: DType) -> NodeId {
368 let s = shape::cast_shape(self.shape(x), to);
369 self.add_node(Op::Cast { to }, vec![x], s)
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// A grouped convolution and a stride-2 transposed convolution both report
    /// the expected output shapes.
    #[test]
    fn inferred_conv2d_and_conv_transpose2d() {
        let dtype = DType::F32;
        let mut graph = Graph::new("conv");
        let input = graph.input("x", Shape::new(&[1, 4, 8, 8], dtype));

        // 3x3 kernel, stride 1, padding 1, dilation 1, two groups.
        let conv_weight = graph.param("w", Shape::new(&[8, 2, 3, 3], dtype));
        let conv_out = graph.conv2d(input, conv_weight, [3, 3], [1, 1], [1, 1], [1, 1], 2);
        let conv_expected = Shape::new(&[1, 8, 8, 8], dtype);
        assert_eq!(graph.shape(conv_out), &conv_expected);

        // 2x2 kernel, stride 2, no padding or output padding, one group.
        let deconv_weight = graph.param("wt", Shape::new(&[4, 8, 2, 2], dtype));
        let deconv_out = graph.conv_transpose2d(
            input,
            deconv_weight,
            [2, 2],
            [2, 2],
            [0, 0],
            [1, 1],
            [0, 0],
            1,
        );
        let deconv_expected = Shape::new(&[1, 8, 16, 16], dtype);
        assert_eq!(graph.shape(deconv_out), &deconv_expected);
    }

    /// `layer_norm2d` returns a node with the same shape as its input.
    #[test]
    fn inferred_layer_norm2d() {
        let dtype = DType::F32;
        let mut graph = Graph::new("ln2d");

        let input = graph.input("x", Shape::new(&[1, 4, 8, 8], dtype));
        let scale = graph.param("g", Shape::new(&[4], dtype));
        let shift = graph.param("b", Shape::new(&[4], dtype));

        let normed = graph.layer_norm2d(input, scale, shift, 1e-6);
        assert_eq!(graph.shape(normed), &Shape::new(&[1, 4, 8, 8], dtype));
    }

    /// matmul -> bias add -> gelu keeps the `[4, 15, 1536]` shape at every step.
    #[test]
    fn inferred_matmul_bias_gelu() {
        let mut graph = Graph::new("test");
        let input = graph.input("x", Shape::new(&[4, 15, 384], DType::F32));
        let weight = graph.param("w", Shape::new(&[384, 1536], DType::F32));
        let bias = graph.param("b", Shape::new(&[1536], DType::F32));

        let product = graph.mm(input, weight);
        let biased = graph.add(product, bias);
        let activated = graph.gelu(biased);
        graph.set_outputs(vec![activated]);

        let expected = Shape::new(&[4, 15, 1536], DType::F32);
        for node in [product, biased, activated] {
            assert_eq!(graph.shape(node), &expected);
        }
    }

    /// A BERT-style feed-forward block (two matmuls, gelu, residual add, layer
    /// norm) infers the intermediate and final shapes.
    #[test]
    fn inferred_bert_ffn() {
        let dtype = DType::F32;
        let hidden = 384;
        let inner = 1536;
        let mut graph = Graph::new("bert_ffn");

        let input = graph.input("x", Shape::new(&[4, 15, hidden], dtype));
        let up_weight = graph.param("int.w", Shape::new(&[hidden, inner], dtype));
        let up_bias = graph.param("int.b", Shape::new(&[inner], dtype));
        let down_weight = graph.param("out.w", Shape::new(&[inner, hidden], dtype));
        let down_bias = graph.param("out.b", Shape::new(&[hidden], dtype));
        let scale = graph.param("g", Shape::new(&[hidden], dtype));
        let shift = graph.param("b", Shape::new(&[hidden], dtype));

        let up = graph.mm(input, up_weight);
        let up_biased = graph.add(up, up_bias);
        let activated = graph.gelu(up_biased);
        let down = graph.mm(activated, down_weight);
        let down_biased = graph.add(down, down_bias);
        let residual = graph.add(down_biased, input);
        let normed = graph.ln(residual, scale, shift, 1e-12);
        graph.set_outputs(vec![normed]);

        let inner_shape = Shape::new(&[4, 15, inner], dtype);
        let hidden_shape = Shape::new(&[4, 15, hidden], dtype);
        assert_eq!(graph.shape(activated), &inner_shape);
        assert_eq!(graph.shape(down_biased), &hidden_shape);
        assert_eq!(graph.shape(normed), &hidden_shape);
    }

    /// Embedding lookup, then reshape, then transpose each infer their shape.
    #[test]
    fn inferred_gather_reshape() {
        let mut graph = Graph::new("test");
        let embeddings = graph.param("emb", Shape::new(&[30522, 384], DType::F32));
        let token_ids = graph.input("ids", Shape::new(&[4, 15], DType::I64));

        let looked_up = graph.gather_(embeddings, token_ids, 0);
        let looked_up_expected = Shape::new(&[4, 15, 384], DType::F32);
        assert_eq!(graph.shape(looked_up), &looked_up_expected);

        let flattened = graph.reshape_(looked_up, vec![60, 384]);
        let flattened_expected = Shape::new(&[60, 384], DType::F32);
        assert_eq!(graph.shape(flattened), &flattened_expected);

        let swapped = graph.transpose_(flattened, vec![1, 0]);
        let swapped_expected = Shape::new(&[384, 60], DType::F32);
        assert_eq!(graph.shape(swapped), &swapped_expected);
    }

    /// Softmax preserves the shape; `mean` drops or keeps the reduced axis
    /// according to `keep_dim`.
    #[test]
    fn inferred_reduce_softmax() {
        let mut graph = Graph::new("test");
        let input = graph.input("x", Shape::new(&[4, 15, 384], DType::F32));

        let probs = graph.sm(input, -1);
        assert_eq!(graph.shape(probs), &Shape::new(&[4, 15, 384], DType::F32));

        let squeezed_mean = graph.mean(input, vec![2], false);
        assert_eq!(graph.shape(squeezed_mean), &Shape::new(&[4, 15], DType::F32));

        let kept_mean = graph.mean(input, vec![2], true);
        assert_eq!(graph.shape(kept_mean), &Shape::new(&[4, 15, 1], DType::F32));
    }
}