use yscv_autograd::Graph;
use yscv_tensor::Tensor;

use crate::{BatchNorm2dLayer, Conv2dLayer, ModelLayer, SequentialModel};

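/// Folds an inference-time BatchNorm2d into the preceding Conv2d.
///
/// BatchNorm computes `y = gamma * (x - mean) / sqrt(var + eps) + beta` using
/// its running statistics. With the per-channel scale
/// `s[c] = gamma[c] / sqrt(running_var[c] + eps)`, the fused convolution uses
/// `w'[.., c] = w[.., c] * s[c]` and `b'[c] = s[c] * (b[c] - mean[c]) + beta[c]`,
/// so `conv'(x) == bn(conv(x))` for a frozen BatchNorm.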
pub fn fuse_conv_bn(conv: &Conv2dLayer, bn: &BatchNorm2dLayer) -> Conv2dLayer {
    let out_channels = conv.out_channels();
    assert_eq!(
        out_channels,
        bn.num_features(),
        "Conv2d out_channels ({}) must match BatchNorm2d num_features ({})",
        out_channels,
        bn.num_features()
    );

    let gamma = bn.gamma().data();
    let beta = bn.beta().data();
    let running_mean = bn.running_mean().data();
    let running_var = bn.running_var().data();
    let eps = bn.epsilon();

    // Per-channel scale: gamma / sqrt(running_var + eps).
    let scale: Vec<f32> = (0..out_channels)
        .map(|c| gamma[c] / (running_var[c] + eps).sqrt())
        .collect();

    // Weight layout is [kh, kw, c_in, out_channels]; see `Tensor::from_vec` below.
    let weight = conv.weight();
    let w_data = weight.data();
    let kh = conv.kernel_h();
    let kw = conv.kernel_w();
    let c_in = conv.in_channels();

    let mut fused_w = vec![0.0f32; w_data.len()];
    for i in 0..kh {
        for j in 0..kw {
            for ci in 0..c_in {
                for co in 0..out_channels {
                    let idx = ((i * kw + j) * c_in + ci) * out_channels + co;
                    fused_w[idx] = w_data[idx] * scale[co];
                }
            }
        }
    }

    // A conv without a bias behaves as if its bias were all zeros.
    let old_bias: Vec<f32> = match conv.bias() {
        Some(b) => b.data().to_vec(),
        None => vec![0.0; out_channels],
    };

    let fused_b: Vec<f32> = (0..out_channels)
        .map(|c| scale[c] * (old_bias[c] - running_mean[c]) + beta[c])
        .collect();

    let fused_weight =
        Tensor::from_vec(vec![kh, kw, c_in, out_channels], fused_w).expect("valid fused weight");
    let fused_bias = Tensor::from_vec(vec![out_channels], fused_b).expect("valid fused bias");

    Conv2dLayer::new(
        c_in,
        out_channels,
        kh,
        kw,
        conv.stride_h(),
        conv.stride_w(),
        fused_weight,
        Some(fused_bias),
    )
    .expect("fused Conv2dLayer construction should not fail")
}

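/// Rewrites a `SequentialModel`, fusing every adjacent `Conv2d` + `BatchNorm2d`
/// pair into a single `Conv2d` via [`fuse_conv_bn`] and copying all other
/// layers through `push_layer`.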
pub fn optimize_sequential(model: &SequentialModel, graph: &mut Graph) -> SequentialModel {
    let layers = model.layers();
    let mut optimized = SequentialModel::new(graph);
    let mut i = 0;

    while i < layers.len() {
        // Fuse a Conv2d immediately followed by a BatchNorm2d into one Conv2d.
        if i + 1 < layers.len()
            && let (ModelLayer::Conv2d(conv), ModelLayer::BatchNorm2d(bn)) =
                (&layers[i], &layers[i + 1])
        {
            let fused = fuse_conv_bn(conv, bn);
            optimized
                .add_conv2d(
                    fused.in_channels(),
                    fused.out_channels(),
                    fused.kernel_h(),
                    fused.kernel_w(),
                    fused.stride_h(),
                    fused.stride_w(),
                    fused.weight().clone(),
                    fused.bias().cloned(),
                )
                .expect("adding fused conv layer should not fail");
            i += 2;
            continue;
        }

        push_layer(&mut optimized, graph, &layers[i]);
        i += 1;
    }

    optimized
}

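/// Re-adds `layer` to `model`, preserving its configuration. Note that the
/// `Linear`, `Embedding`, and `LoraLinear` arms below rebuild the layer with
/// zero-initialized parameters (`add_linear_zero`, `Tensor::zeros`) rather
/// than carrying trained values over, and the `LayerNorm`/`GroupNorm` arms
/// use a fixed `1e-5` epsilon. Variants the builder does not expose are
/// pushed through unchanged via `push_raw_layer`.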
fn push_layer(model: &mut SequentialModel, graph: &mut Graph, layer: &ModelLayer) {
    match layer {
        ModelLayer::Conv2d(l) => {
            model
                .add_conv2d(
                    l.in_channels(),
                    l.out_channels(),
                    l.kernel_h(),
                    l.kernel_w(),
                    l.stride_h(),
                    l.stride_w(),
                    l.weight().clone(),
                    l.bias().cloned(),
                )
                .expect("add_conv2d");
        }
        ModelLayer::BatchNorm2d(l) => {
            model
                .add_batch_norm2d(
                    l.num_features(),
                    l.epsilon(),
                    l.gamma().clone(),
                    l.beta().clone(),
                    l.running_mean().clone(),
                    l.running_var().clone(),
                )
                .expect("add_batch_norm2d");
        }
        ModelLayer::ReLU(_) => model.add_relu(),
        ModelLayer::LeakyReLU(l) => {
            model
                .add_leaky_relu(l.negative_slope())
                .expect("add_leaky_relu");
        }
        ModelLayer::Sigmoid(_) => model.add_sigmoid(),
        ModelLayer::Tanh(_) => model.add_tanh(),
        ModelLayer::Dropout(l) => {
            model.add_dropout(l.rate()).expect("add_dropout");
        }
        ModelLayer::Flatten(_) => model.add_flatten(),
        ModelLayer::Softmax(_) => model.add_softmax(),
        ModelLayer::GlobalAvgPool2d(_) => model.add_global_avg_pool2d(),
        ModelLayer::MaxPool2d(l) => {
            model
                .add_max_pool2d(l.kernel_h(), l.kernel_w(), l.stride_h(), l.stride_w())
                .expect("add_max_pool2d");
        }
        ModelLayer::AvgPool2d(l) => {
            model
                .add_avg_pool2d(l.kernel_h(), l.kernel_w(), l.stride_h(), l.stride_w())
                .expect("add_avg_pool2d");
        }
        ModelLayer::Linear(l) => {
            // NOTE: rebuilt with zero-initialized weights; the trained
            // parameters of the original layer are not carried over.
            model
                .add_linear_zero(graph, l.in_features(), l.out_features())
                .expect("add_linear_zero");
        }
        ModelLayer::Embedding(l) => {
            // NOTE: placeholder all-zero table; trained embeddings are not preserved.
            let weight = Tensor::zeros(vec![l.num_embeddings(), l.embedding_dim()])
                .expect("embedding weight");
            model
                .add_embedding(graph, l.num_embeddings(), l.embedding_dim(), weight)
                .expect("add_embedding");
        }
        ModelLayer::LayerNorm(l) => {
            // NOTE: epsilon is reset to the 1e-5 default rather than copied.
            model
                .add_layer_norm(graph, l.normalized_shape(), 1e-5)
                .expect("add_layer_norm");
        }
        ModelLayer::GroupNorm(l) => {
            model
                .add_group_norm(graph, l.num_groups(), l.num_channels(), 1e-5)
                .expect("add_group_norm");
        }
        ModelLayer::DepthwiseConv2d(l) => {
            model
                .add_depthwise_conv2d(
                    l.channels(),
                    l.kernel_h(),
                    l.kernel_w(),
                    l.stride_h(),
                    l.stride_w(),
                    l.weight().clone(),
                    l.bias().cloned(),
                )
                .expect("add_depthwise_conv2d");
        }
        ModelLayer::SeparableConv2d(l) => {
            model
                .add_separable_conv2d(
                    l.in_channels(),
                    l.out_channels(),
                    l.kernel_h(),
                    l.kernel_w(),
                    l.stride_h(),
                    l.stride_w(),
                    l.depthwise().weight().clone(),
                    l.pointwise().weight().clone(),
                    l.pointwise().bias().cloned(),
                )
                .expect("add_separable_conv2d");
        }
        ModelLayer::LoraLinear(l) => {
            // NOTE: collapsed to a plain zero-initialized linear layer; the
            // base weights and LoRA adapters are not preserved.
            model
                .add_linear_zero(graph, l.in_features, l.out_features)
                .expect("add_linear_zero for lora");
        }
        ModelLayer::Conv1d(_)
        | ModelLayer::Conv3d(_)
        | ModelLayer::ConvTranspose2d(_)
        | ModelLayer::AdaptiveAvgPool2d(_)
        | ModelLayer::AdaptiveMaxPool2d(_)
        | ModelLayer::InstanceNorm(_)
        | ModelLayer::PixelShuffle(_)
        | ModelLayer::Upsample(_)
        | ModelLayer::GELU(_)
        | ModelLayer::SiLU(_)
        | ModelLayer::Mish(_)
        | ModelLayer::PReLU(_)
        | ModelLayer::ResidualBlock(_)
        | ModelLayer::Rnn(_)
        | ModelLayer::Lstm(_)
        | ModelLayer::Gru(_)
        | ModelLayer::MultiHeadAttention(_)
        | ModelLayer::TransformerEncoder(_)
        | ModelLayer::FeedForward(_)
        | ModelLayer::DeformableConv2d(_) => {
            model.push_raw_layer(layer.clone());
        }
    }
}
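
// A minimal sketch of a regression test for `fuse_conv_bn`, assuming a
// hypothetical `BatchNorm2dLayer::new(num_features, epsilon, gamma, beta,
// running_mean, running_var)` constructor that mirrors the `add_batch_norm2d`
// parameter order above; adjust to this crate's actual constructor before use.
#[cfg(test)]
mod tests {
    use super::*;
    use yscv_tensor::Tensor;

    #[test]
    fn fuse_conv_bn_matches_closed_form() {
        // 1x1 conv, one input and one output channel: w = 2.0, b = 1.0.
        let weight = Tensor::from_vec(vec![1, 1, 1, 1], vec![2.0]).expect("weight");
        let bias = Tensor::from_vec(vec![1], vec![1.0]).expect("bias");
        let conv = Conv2dLayer::new(1, 1, 1, 1, 1, 1, weight, Some(bias)).expect("conv");

        let t = |v: f32| Tensor::from_vec(vec![1], vec![v]).expect("scalar tensor");
        // gamma = 3, beta = 0.5, mean = 1, var = 4, eps = 0 => scale = 3 / 2 = 1.5.
        let bn = BatchNorm2dLayer::new(1, 0.0, t(3.0), t(0.5), t(1.0), t(4.0));

        let fused = fuse_conv_bn(&conv, &bn);
        // w' = 2.0 * 1.5 = 3.0; b' = 1.5 * (1.0 - 1.0) + 0.5 = 0.5.
        assert!((fused.weight().data()[0] - 3.0).abs() < 1e-6);
        assert!((fused.bias().expect("fused bias").data()[0] - 0.5).abs() < 1e-6);
    }
}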