1use crate::dtype::scalar_constant_bytes;
37use crate::op::*;
38use crate::shape;
39use crate::{DType, Graph, NodeId, Op, Shape};
40
41pub trait GraphExt {
43 fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
45
46 fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
48 fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
49 fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
50 fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
51
52 fn gelu(&mut self, x: NodeId) -> NodeId;
54 fn gelu_approx(&mut self, x: NodeId) -> NodeId;
59 fn silu(&mut self, x: NodeId) -> NodeId;
60 fn relu(&mut self, x: NodeId) -> NodeId;
61 fn exp(&mut self, x: NodeId) -> NodeId;
62 fn sqrt(&mut self, x: NodeId) -> NodeId;
63 fn neg(&mut self, x: NodeId) -> NodeId;
64 fn tanh(&mut self, x: NodeId) -> NodeId;
65
66 fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
68 fn layer_norm2d(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
69 fn group_norm(
70 &mut self,
71 x: NodeId,
72 gamma: NodeId,
73 beta: NodeId,
74 num_groups: usize,
75 eps: f32,
76 ) -> NodeId;
77 fn rms_norm(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
78
79 fn conv2d(
81 &mut self,
82 input: NodeId,
83 weight: NodeId,
84 kernel_size: [usize; 2],
85 stride: [usize; 2],
86 padding: [usize; 2],
87 dilation: [usize; 2],
88 groups: usize,
89 ) -> NodeId;
90 fn conv_transpose2d(
91 &mut self,
92 input: NodeId,
93 weight: NodeId,
94 kernel_size: [usize; 2],
95 stride: [usize; 2],
96 padding: [usize; 2],
97 dilation: [usize; 2],
98 output_padding: [usize; 2],
99 groups: usize,
100 ) -> NodeId;
101
102 fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
104 fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
105 fn sm(&mut self, x: NodeId, axis: i32) -> NodeId;
106
107 fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId;
109 fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId;
110 fn narrow_(&mut self, x: NodeId, axis: usize, start: usize, len: usize) -> NodeId;
111 fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId;
112 fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId;
113
114 fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
116 fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
117
118 fn attention_(
120 &mut self,
121 q: NodeId,
122 k: NodeId,
123 v: NodeId,
124 mask: NodeId,
125 num_heads: usize,
126 head_dim: usize,
127 ) -> NodeId;
128
129 fn rope(&mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize) -> NodeId;
131 fn rope_n(
133 &mut self,
134 x: NodeId,
135 cos: NodeId,
136 sin: NodeId,
137 head_dim: usize,
138 n_rot: usize,
139 ) -> NodeId;
140
141 fn cast(&mut self, x: NodeId, to: DType) -> NodeId;
143
144 fn constant(&mut self, value: f64, dtype: DType) -> NodeId;
148
149 fn try_constant(&mut self, value: f64, dtype: DType) -> Result<NodeId, String>;
154
155 fn stop_gradient(&mut self, x: NodeId) -> NodeId;
161}
162
163impl GraphExt for Graph {
164 fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
165 let s =
166 shape::matmul_shape(self.shape(lhs), self.shape(rhs)).expect("matmul shape inference");
167 self.matmul(lhs, rhs, s)
168 }
169
170 fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
171 let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("add shape inference");
172 self.binary(BinaryOp::Add, lhs, rhs, s)
173 }
174
175 fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
176 let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("sub shape inference");
177 self.binary(BinaryOp::Sub, lhs, rhs, s)
178 }
179
180 fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
181 let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("mul shape inference");
182 self.binary(BinaryOp::Mul, lhs, rhs, s)
183 }
184
185 fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
186 let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("div shape inference");
187 self.binary(BinaryOp::Div, lhs, rhs, s)
188 }
189
190 fn gelu(&mut self, x: NodeId) -> NodeId {
191 let s = shape::unary_shape(self.shape(x));
192 self.activation(Activation::Gelu, x, s)
193 }
194
195 fn gelu_approx(&mut self, x: NodeId) -> NodeId {
196 let s = shape::unary_shape(self.shape(x));
197 self.activation(Activation::GeluApprox, x, s)
198 }
199
200 fn silu(&mut self, x: NodeId) -> NodeId {
201 let s = shape::unary_shape(self.shape(x));
202 self.activation(Activation::Silu, x, s)
203 }
204
205 fn relu(&mut self, x: NodeId) -> NodeId {
206 let s = shape::unary_shape(self.shape(x));
207 self.activation(Activation::Relu, x, s)
208 }
209
210 fn exp(&mut self, x: NodeId) -> NodeId {
211 let s = shape::unary_shape(self.shape(x));
212 self.activation(Activation::Exp, x, s)
213 }
214
215 fn sqrt(&mut self, x: NodeId) -> NodeId {
216 let s = shape::unary_shape(self.shape(x));
217 self.activation(Activation::Sqrt, x, s)
218 }
219
220 fn neg(&mut self, x: NodeId) -> NodeId {
221 let s = shape::unary_shape(self.shape(x));
222 self.activation(Activation::Neg, x, s)
223 }
224
225 fn tanh(&mut self, x: NodeId) -> NodeId {
226 let s = shape::unary_shape(self.shape(x));
227 self.activation(Activation::Tanh, x, s)
228 }
229
230 fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
231 let s = shape::unary_shape(self.shape(x));
232 self.layer_norm(x, gamma, beta, -1, eps, s)
233 }
234
235 fn layer_norm2d(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
236 Graph::layer_norm2d(self, x, gamma, beta, eps)
237 }
238
239 fn group_norm(
240 &mut self,
241 x: NodeId,
242 gamma: NodeId,
243 beta: NodeId,
244 num_groups: usize,
245 eps: f32,
246 ) -> NodeId {
247 Graph::group_norm(self, x, gamma, beta, num_groups, eps)
248 }
249
250 fn conv2d(
251 &mut self,
252 input: NodeId,
253 weight: NodeId,
254 kernel_size: [usize; 2],
255 stride: [usize; 2],
256 padding: [usize; 2],
257 dilation: [usize; 2],
258 groups: usize,
259 ) -> NodeId {
260 Graph::conv2d(
261 self,
262 input,
263 weight,
264 kernel_size,
265 stride,
266 padding,
267 dilation,
268 groups,
269 )
270 }
271
272 fn conv_transpose2d(
273 &mut self,
274 input: NodeId,
275 weight: NodeId,
276 kernel_size: [usize; 2],
277 stride: [usize; 2],
278 padding: [usize; 2],
279 dilation: [usize; 2],
280 output_padding: [usize; 2],
281 groups: usize,
282 ) -> NodeId {
283 Graph::conv_transpose2d(
284 self,
285 input,
286 weight,
287 kernel_size,
288 stride,
289 padding,
290 dilation,
291 output_padding,
292 groups,
293 )
294 }
295
296 fn rms_norm(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
297 let s = shape::unary_shape(self.shape(x));
298 self.add_node(Op::RmsNorm { axis: -1, eps }, vec![x, gamma, beta], s)
299 }
300
301 fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId {
302 let s =
303 shape::reduce_shape(self.shape(x), &axes, keep_dim).expect("reduce shape inference");
304 self.reduce(x, ReduceOp::Sum, axes, keep_dim, s)
305 }
306
307 fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId {
308 let s =
309 shape::reduce_shape(self.shape(x), &axes, keep_dim).expect("reduce shape inference");
310 self.reduce(x, ReduceOp::Mean, axes, keep_dim, s)
311 }
312
313 fn sm(&mut self, x: NodeId, axis: i32) -> NodeId {
314 let s = shape::softmax_shape(self.shape(x));
315 self.softmax(x, axis, s)
316 }
317
318 fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId {
319 let s = shape::reshape_shape(self.shape(x), &new_shape).expect("reshape shape inference");
320 self.reshape(x, new_shape, s)
321 }
322
323 fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId {
324 let s = shape::transpose_shape(self.shape(x), &perm).expect("transpose shape inference");
325 self.add_node(Op::Transpose { perm }, vec![x], s)
326 }
327
328 fn narrow_(&mut self, x: NodeId, axis: usize, start: usize, len: usize) -> NodeId {
329 let s = shape::narrow_shape(self.shape(x), axis, len).expect("narrow shape inference");
330 self.add_node(Op::Narrow { axis, start, len }, vec![x], s)
331 }
332
333 fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId {
334 let shapes: Vec<&Shape> = inputs.iter().map(|&id| self.shape(id)).collect();
335 let s = shape::concat_shape(&shapes, axis).expect("concat shape inference");
336 self.concat(inputs, axis, s)
337 }
338
339 fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId {
340 let s = shape::gather_shape(self.shape(table), self.shape(indices), axis)
341 .expect("gather shape inference");
342 self.gather(table, indices, axis, s)
343 }
344
345 fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
346 let s = shape::compare_shape(self.shape(lhs), self.shape(rhs))
347 .expect("compare shape inference");
348 self.add_node(Op::Compare(CmpOp::Eq), vec![lhs, rhs], s)
349 }
350
351 fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
352 let s = shape::compare_shape(self.shape(lhs), self.shape(rhs))
353 .expect("compare shape inference");
354 self.add_node(Op::Compare(CmpOp::Lt), vec![lhs, rhs], s)
355 }
356
357 fn attention_(
358 &mut self,
359 q: NodeId,
360 k: NodeId,
361 v: NodeId,
362 mask: NodeId,
363 num_heads: usize,
364 head_dim: usize,
365 ) -> NodeId {
366 let s = shape::attention_shape(self.shape(q));
367 self.attention(q, k, v, mask, num_heads, head_dim, s)
368 }
369
370 fn rope(&mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize) -> NodeId {
371 self.rope_n(x, cos, sin, head_dim, head_dim)
372 }
373
374 fn rope_n(
375 &mut self,
376 x: NodeId,
377 cos: NodeId,
378 sin: NodeId,
379 head_dim: usize,
380 n_rot: usize,
381 ) -> NodeId {
382 assert!(
383 n_rot <= head_dim && n_rot.is_multiple_of(2),
384 "rope_n: n_rot={n_rot} must be even and <= head_dim={head_dim}"
385 );
386 let s = shape::unary_shape(self.shape(x));
387 self.add_node(Op::Rope { head_dim, n_rot }, vec![x, cos, sin], s)
388 }
389
390 fn cast(&mut self, x: NodeId, to: DType) -> NodeId {
391 let s = shape::cast_shape(self.shape(x), to);
392 self.add_node(Op::Cast { to }, vec![x], s)
393 }
394
395 fn try_constant(&mut self, value: f64, dtype: DType) -> Result<NodeId, String> {
396 if matches!(dtype, DType::F16 | DType::BF16) {
397 let f32_id = self.try_constant(value, DType::F32)?;
398 return Ok(self.cast(f32_id, dtype));
399 }
400 let data = scalar_constant_bytes(value, dtype)?;
401 Ok(self.add_node(Op::Constant { data }, vec![], Shape::scalar(dtype)))
402 }
403
404 fn constant(&mut self, value: f64, dtype: DType) -> NodeId {
405 self.try_constant(value, dtype)
406 .expect("scalar constant encoding")
407 }
408
409 fn stop_gradient(&mut self, x: NodeId) -> NodeId {
410 let s = shape::unary_shape(self.shape(x));
411 self.add_node(Op::StopGradient, vec![x], s)
412 }
413}
414
#[cfg(test)]
mod tests {
    //! Shape-inference checks for the `GraphExt` builders.

    use super::*;

    /// Grouped conv2d and stride-2 transposed conv produce the expected shapes.
    #[test]
    fn inferred_conv2d_and_conv_transpose2d() {
        let dtype = DType::F32;
        let mut graph = Graph::new("conv");
        let input = graph.input("x", Shape::new(&[1, 4, 8, 8], dtype));

        let conv_weight = graph.param("w", Shape::new(&[8, 2, 3, 3], dtype));
        let conv_out = graph.conv2d(input, conv_weight, [3, 3], [1, 1], [1, 1], [1, 1], 2);
        assert_eq!(graph.shape(conv_out), &Shape::new(&[1, 8, 8, 8], dtype));

        let deconv_weight = graph.param("wt", Shape::new(&[4, 8, 2, 2], dtype));
        let deconv_out = graph.conv_transpose2d(
            input,
            deconv_weight,
            [2, 2],
            [2, 2],
            [0, 0],
            [1, 1],
            [0, 0],
            1,
        );
        assert_eq!(graph.shape(deconv_out), &Shape::new(&[1, 8, 16, 16], dtype));
    }

    /// `layer_norm2d` keeps the input shape.
    #[test]
    fn inferred_layer_norm2d() {
        let dtype = DType::F32;
        let mut graph = Graph::new("ln2d");
        let input = graph.input("x", Shape::new(&[1, 4, 8, 8], dtype));
        let scale = graph.param("g", Shape::new(&[4], dtype));
        let shift = graph.param("b", Shape::new(&[4], dtype));

        let normed = graph.layer_norm2d(input, scale, shift, 1e-6);
        assert_eq!(graph.shape(normed), &Shape::new(&[1, 4, 8, 8], dtype));
    }

    /// matmul -> bias add -> gelu all yield `[4, 15, 1536]`.
    #[test]
    fn inferred_matmul_bias_gelu() {
        let mut graph = Graph::new("test");
        let input = graph.input("x", Shape::new(&[4, 15, 384], DType::F32));
        let weight = graph.param("w", Shape::new(&[384, 1536], DType::F32));
        let bias = graph.param("b", Shape::new(&[1536], DType::F32));

        let product = graph.mm(input, weight);
        let biased = graph.add(product, bias);
        let activated = graph.gelu(biased);
        graph.set_outputs(vec![activated]);

        let expected = Shape::new(&[4, 15, 1536], DType::F32);
        assert_eq!(graph.shape(product), &expected);
        assert_eq!(graph.shape(biased), &expected);
        assert_eq!(graph.shape(activated), &expected);
    }

    /// A BERT-style feed-forward block with residual and layer norm.
    #[test]
    fn inferred_bert_ffn() {
        let dtype = DType::F32;
        let hidden = 384;
        let inner = 1536;
        let mut graph = Graph::new("bert_ffn");

        let input = graph.input("x", Shape::new(&[4, 15, hidden], dtype));
        let up_weight = graph.param("int.w", Shape::new(&[hidden, inner], dtype));
        let up_bias = graph.param("int.b", Shape::new(&[inner], dtype));
        let down_weight = graph.param("out.w", Shape::new(&[inner, hidden], dtype));
        let down_bias = graph.param("out.b", Shape::new(&[hidden], dtype));
        let scale = graph.param("g", Shape::new(&[hidden], dtype));
        let shift = graph.param("b", Shape::new(&[hidden], dtype));

        let up = graph.mm(input, up_weight);
        let up_biased = graph.add(up, up_bias);
        let activated = graph.gelu(up_biased);
        let down = graph.mm(activated, down_weight);
        let projected = graph.add(down, down_bias);
        let residual = graph.add(projected, input);
        let normed = graph.ln(residual, scale, shift, 1e-12);
        graph.set_outputs(vec![normed]);

        assert_eq!(graph.shape(activated), &Shape::new(&[4, 15, inner], dtype));
        assert_eq!(graph.shape(projected), &Shape::new(&[4, 15, hidden], dtype));
        assert_eq!(graph.shape(normed), &Shape::new(&[4, 15, hidden], dtype));
    }

    /// gather, then reshape, then transpose.
    #[test]
    fn inferred_gather_reshape() {
        let mut graph = Graph::new("test");
        let embeddings = graph.param("emb", Shape::new(&[30522, 384], DType::F32));
        let token_ids = graph.input("ids", Shape::new(&[4, 15], DType::I64));

        let looked_up = graph.gather_(embeddings, token_ids, 0);
        assert_eq!(
            graph.shape(looked_up),
            &Shape::new(&[4, 15, 384], DType::F32)
        );

        let flat = graph.reshape_(looked_up, vec![60, 384]);
        assert_eq!(graph.shape(flat), &Shape::new(&[60, 384], DType::F32));

        let swapped = graph.transpose_(flat, vec![1, 0]);
        assert_eq!(graph.shape(swapped), &Shape::new(&[384, 60], DType::F32));
    }

    /// A scalar constant multiplied into a `[2, 3]` input keeps `[2, 3]`.
    #[test]
    fn inferred_constant_broadcasts() {
        let mut graph = Graph::new("const");
        let input = graph.input("x", Shape::new(&[2, 3], DType::F32));

        let two = graph.constant(2.0, DType::F32);
        assert_eq!(graph.shape(two), &Shape::scalar(DType::F32));

        let scaled = graph.mul(input, two);
        assert_eq!(graph.shape(scaled), &Shape::new(&[2, 3], DType::F32));
    }

    /// An F16 constant is a scalar of dtype F16 and combines with F16 input.
    #[test]
    fn inferred_constant_f16_via_cast() {
        let mut graph = Graph::new("const_f16");

        let half = graph.constant(1.5, DType::F16);
        assert_eq!(graph.shape(half), &Shape::scalar(DType::F16));

        let input = graph.input("x", Shape::new(&[2], DType::F16));
        let shifted = graph.add(input, half);
        assert_eq!(graph.shape(shifted), &Shape::new(&[2], DType::F16));
    }

    /// `(x + 1) / 2` with scalar constants keeps the input shape.
    #[test]
    fn inferred_constant_arithmetic_chain() {
        let mut graph = Graph::new("const_chain");
        let input = graph.input("x", Shape::new(&[4], DType::F32));
        let one = graph.constant(1.0, DType::F32);
        let two = graph.constant(2.0, DType::F32);

        let plus_one = graph.add(input, one);
        let halved = graph.div(plus_one, two);
        assert_eq!(graph.shape(halved), &Shape::new(&[4], DType::F32));
        graph.set_outputs(vec![halved]);
    }

    /// 128 does not fit in I8, so `try_constant` reports "out of range".
    #[test]
    fn try_constant_rejects_out_of_range() {
        let mut graph = Graph::new("try_const");
        let message = graph.try_constant(128.0, DType::I8).unwrap_err();
        assert!(message.contains("out of range"));
    }

    /// `try_constant` succeeds for F16 and yields an F16 scalar.
    #[test]
    fn try_constant_f16_via_cast() {
        let mut graph = Graph::new("try_const_f16");
        let half = graph.try_constant(1.5, DType::F16).unwrap();
        assert_eq!(graph.shape(half), &Shape::scalar(DType::F16));
    }

    /// Softmax keeps the shape; mean drops or keeps the reduced axis.
    #[test]
    fn inferred_reduce_softmax() {
        let mut graph = Graph::new("test");
        let input = graph.input("x", Shape::new(&[4, 15, 384], DType::F32));

        let probs = graph.sm(input, -1);
        assert_eq!(graph.shape(probs), &Shape::new(&[4, 15, 384], DType::F32));

        let mean_dropped = graph.mean(input, vec![2], false);
        assert_eq!(graph.shape(mean_dropped), &Shape::new(&[4, 15], DType::F32));

        let mean_kept = graph.mean(input, vec![2], true);
        assert_eq!(graph.shape(mean_kept), &Shape::new(&[4, 15, 1], DType::F32));
    }
}