1use crate::infer::GraphExt;
24use crate::op::Activation;
25use crate::{Graph, NodeId, Op, Shape};
26
27impl Graph {
28 pub fn linear_bias(&mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>) -> NodeId {
30 let mm = self.mm(input, weight);
31 match bias {
32 Some(b) => self.add(mm, b),
33 None => mm,
34 }
35 }
36
37 pub fn linear_bias_act(
39 &mut self,
40 input: NodeId,
41 weight: NodeId,
42 bias: Option<NodeId>,
43 activation: Option<Activation>,
44 ) -> NodeId {
45 let x = self.linear_bias(input, weight, bias);
46 activation.map_or(x, |act| self.activation_by_kind(x, act))
47 }
48
49 pub fn linear_fused(
52 &mut self,
53 input: NodeId,
54 weight: NodeId,
55 bias: NodeId,
56 activation: Option<Activation>,
57 out_shape: Shape,
58 ) -> NodeId {
59 self.fused_matmul_bias_act(input, weight, bias, activation, out_shape)
60 }
61
62 pub fn shared_matmul_pair(
69 &mut self,
70 input: NodeId,
71 w_first: NodeId,
72 w_second: NodeId,
73 ) -> (NodeId, NodeId) {
74 let first = self.mm(input, w_first);
75 let second = self.mm(input, w_second);
76 (first, second)
77 }
78
79 pub fn swiglu_ffn(
84 &mut self,
85 input: NodeId,
86 up_w: NodeId,
87 gate_w: NodeId,
88 down_w: NodeId,
89 ) -> NodeId {
90 let (up, gate) = self.shared_matmul_pair(input, up_w, gate_w);
91 let gate_silu = self.silu(gate);
92 let hidden = self.mul(up, gate_silu);
93 self.mm(hidden, down_w)
94 }
95
96 pub fn fused_swiglu_ffn(
101 &mut self,
102 input: NodeId,
103 up_w: NodeId,
104 gate_w: NodeId,
105 down_w: NodeId,
106 out_shape: Shape,
107 ) -> NodeId {
108 let wu_shape = self.shape(up_w);
109 let wg_shape = self.shape(gate_w);
110 let k = wu_shape.dim(0).unwrap_static();
111 let n_up = wu_shape.dim(1).unwrap_static();
112 let n_gate = wg_shape.dim(1).unwrap_static();
113 debug_assert_eq!(wu_shape.dim(0), wg_shape.dim(0));
114
115 let concat_shape = Shape::new(&[k, n_up + n_gate], wu_shape.dtype());
116 let concat_w = self.concat(vec![up_w, gate_w], 1, concat_shape);
117
118 let input_shape = self.shape(input);
119 let out_rank = input_shape.rank();
120 let dtype = input_shape.dtype();
121 let mut cat_dims: Vec<usize> = (0..out_rank)
122 .map(|i| input_shape.dim(i).unwrap_static())
123 .collect();
124 cat_dims[out_rank - 1] = n_up + n_gate;
125 let cat_shape = Shape::new(&cat_dims, dtype);
126 let cat_mm = self.matmul(input, concat_w, cat_shape);
127
128 let mut hidden_dims = cat_dims;
129 hidden_dims[out_rank - 1] = n_up;
130 let hidden_shape = Shape::new(&hidden_dims, dtype);
131 let hidden = self.add_node(
132 Op::FusedSwiGLU {
133 cast_to: None,
134 gate_first: false,
135 },
136 vec![cat_mm],
137 hidden_shape,
138 );
139
140 let _ = out_shape;
141 self.mm(hidden, down_w)
142 }
143
144 fn activation_by_kind(&mut self, x: NodeId, act: Activation) -> NodeId {
145 match act {
146 Activation::Gelu => self.gelu(x),
147 Activation::GeluApprox => self.gelu_approx(x),
148 Activation::Silu => self.silu(x),
149 Activation::Relu => self.relu(x),
150 Activation::Exp => self.exp(x),
151 Activation::Sqrt => self.sqrt(x),
152 Activation::Neg => self.neg(x),
153 Activation::Tanh => self.tanh(x),
154 Activation::Sigmoid => {
155 let s = self.shape(x).clone();
156 self.activation(Activation::Sigmoid, x, s)
157 }
158 Activation::Log => {
159 let s = self.shape(x).clone();
160 self.activation(Activation::Log, x, s)
161 }
162 _ => {
163 let s = self.shape(x).clone();
164 self.activation(act, x, s)
165 }
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::op::BinaryOp;
    use crate::{DType, Op};

    /// Convenience constructor for an `F32` shape with the given dims.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    #[test]
    fn linear_bias_act_emits_canonical_chain() {
        let mut g = Graph::new("linear");
        let x = g.input("x", f32_shape(&[4, 8]));
        let w = g.param("w", f32_shape(&[8, 16]));
        let b = g.param("b", f32_shape(&[16]));
        let out = g.linear_bias_act(x, w, Some(b), Some(Activation::Silu));
        g.set_outputs(vec![out]);

        let act = g.node(out);
        assert!(matches!(act.op, Op::Activation(Activation::Silu)));
        let add = g.node(act.inputs[0]);
        assert!(matches!(add.op, Op::Binary(BinaryOp::Add)));
        let mm = g.node(add.inputs[0]);
        assert!(matches!(mm.op, Op::MatMul));
    }

    #[test]
    fn linear_bias_without_bias_is_bare_matmul() {
        let mut g = Graph::new("linear_no_bias");
        let x = g.input("x", f32_shape(&[4, 8]));
        let w = g.param("w", f32_shape(&[8, 16]));
        let out = g.linear_bias(x, w, None);
        g.set_outputs(vec![out]);

        assert!(matches!(g.node(out).op, Op::MatMul));
    }

    #[test]
    fn linear_bias_act_without_activation_ends_at_add() {
        let mut g = Graph::new("linear_no_act");
        let x = g.input("x", f32_shape(&[4, 8]));
        let w = g.param("w", f32_shape(&[8, 16]));
        let b = g.param("b", f32_shape(&[16]));
        let out = g.linear_bias_act(x, w, Some(b), None);
        g.set_outputs(vec![out]);

        let add = g.node(out);
        assert!(matches!(add.op, Op::Binary(BinaryOp::Add)));
        let mm = g.node(add.inputs[0]);
        assert!(matches!(mm.op, Op::MatMul));
    }

    #[test]
    fn swiglu_ffn_gates_up_projection_with_silu() {
        let mut g = Graph::new("swiglu");
        let x = g.input("x", f32_shape(&[4, 8]));
        let up = g.param("up", f32_shape(&[8, 16]));
        let gate = g.param("gate", f32_shape(&[8, 16]));
        let down = g.param("down", f32_shape(&[16, 8]));
        let out = g.swiglu_ffn(x, up, gate, down);
        g.set_outputs(vec![out]);

        let down_mm = g.node(out);
        assert!(matches!(down_mm.op, Op::MatMul));
        let mul = g.node(down_mm.inputs[0]);
        let silu = g.node(mul.inputs[1]);
        assert!(matches!(silu.op, Op::Activation(Activation::Silu)));
        let gate_mm = g.node(silu.inputs[0]);
        assert!(matches!(gate_mm.op, Op::MatMul));
    }
}