1use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use crate::{
5 AvgPool2dLayer, BatchNorm2dLayer, Conv2dLayer, DeformableConv2dLayer, DepthwiseConv2dLayer,
6 DropoutLayer, EmbeddingLayer, FeedForwardLayer, FlattenLayer, GlobalAvgPool2dLayer,
7 GroupNormLayer, GruLayer, LayerCheckpoint, LayerNormLayer, LeakyReLULayer, LinearLayer,
8 LoraConfig, LoraLinear, LstmLayer, MaxPool2dLayer, ModelError, ModelLayer,
9 MultiHeadAttentionLayer, ReLULayer, ResidualBlock, RnnLayer, SeparableConv2dLayer,
10 SequentialCheckpoint, SigmoidLayer, SoftmaxLayer, TanhLayer, TensorSnapshot,
11 TransformerEncoderLayer, optimize_sequential,
12};
13
/// A sequential (feed-forward) stack of layers, executed in insertion order.
#[derive(Debug, Clone)]
pub struct SequentialModel {
    // Layers in execution order.
    layers: Vec<ModelLayer>,
    // Per-layer freeze flags; kept in lockstep with `layers` (same length),
    // updated by every `add_*` method.
    frozen: Vec<bool>,
    // Graph node count recorded after the last parameter-creating mutation.
    // NOTE(review): presumably marks the boundary between persistent parameter
    // nodes and transient forward-pass nodes — confirm against checkpoint code.
    persistent_node_count: usize,
    // Training-mode flag; propagated to mode-dependent layers (e.g. dropout).
    training: bool,
}
22
23impl SequentialModel {
24 pub fn new(graph: &Graph) -> Self {
26 Self {
27 layers: Vec::new(),
28 frozen: Vec::new(),
29 persistent_node_count: graph.node_count(),
30 training: true,
31 }
32 }
33
34 pub fn add_linear(
35 &mut self,
36 graph: &mut Graph,
37 in_features: usize,
38 out_features: usize,
39 weight_init: Tensor,
40 bias_init: Tensor,
41 ) -> Result<(), ModelError> {
42 let layer = LinearLayer::new(graph, in_features, out_features, weight_init, bias_init)?;
43 self.layers.push(ModelLayer::Linear(layer));
44 self.frozen.push(false);
45 self.persistent_node_count = graph.node_count();
46 Ok(())
47 }
48
49 pub fn add_linear_zero(
50 &mut self,
51 graph: &mut Graph,
52 in_features: usize,
53 out_features: usize,
54 ) -> Result<(), ModelError> {
55 let layer = LinearLayer::zero_init(graph, in_features, out_features)?;
56 self.layers.push(ModelLayer::Linear(layer));
57 self.frozen.push(false);
58 self.persistent_node_count = graph.node_count();
59 Ok(())
60 }
61
62 pub fn add_relu(&mut self) {
63 self.layers.push(ModelLayer::ReLU(ReLULayer::new()));
64 self.frozen.push(false);
65 }
66
67 pub fn add_leaky_relu(&mut self, negative_slope: f32) -> Result<(), ModelError> {
68 let layer = LeakyReLULayer::new(negative_slope)?;
69 self.layers.push(ModelLayer::LeakyReLU(layer));
70 self.frozen.push(false);
71 Ok(())
72 }
73
74 pub fn add_sigmoid(&mut self) {
75 self.layers.push(ModelLayer::Sigmoid(SigmoidLayer::new()));
76 self.frozen.push(false);
77 }
78
79 pub fn add_tanh(&mut self) {
80 self.layers.push(ModelLayer::Tanh(TanhLayer::new()));
81 self.frozen.push(false);
82 }
83
84 pub fn add_gelu(&mut self) {
85 self.layers.push(ModelLayer::GELU(crate::GELULayer::new()));
86 self.frozen.push(false);
87 }
88
89 pub fn add_silu(&mut self) {
90 self.layers.push(ModelLayer::SiLU(crate::SiLULayer::new()));
91 self.frozen.push(false);
92 }
93
94 pub fn add_mish(&mut self) {
95 self.layers.push(ModelLayer::Mish(crate::MishLayer::new()));
96 self.frozen.push(false);
97 }
98
99 pub fn add_prelu(&mut self, alpha: Vec<f32>) {
100 self.layers
101 .push(ModelLayer::PReLU(crate::PReLULayer::new(alpha)));
102 self.frozen.push(false);
103 }
104
105 pub fn add_dropout(&mut self, rate: f32) -> Result<(), ModelError> {
106 let layer = DropoutLayer::new(rate)?;
107 self.layers.push(ModelLayer::Dropout(layer));
108 self.frozen.push(false);
109 Ok(())
110 }
111
112 #[allow(clippy::too_many_arguments)]
113 pub fn add_conv2d(
114 &mut self,
115 in_channels: usize,
116 out_channels: usize,
117 kernel_h: usize,
118 kernel_w: usize,
119 stride_h: usize,
120 stride_w: usize,
121 weight: Tensor,
122 bias: Option<Tensor>,
123 ) -> Result<(), ModelError> {
124 let layer = Conv2dLayer::new(
125 in_channels,
126 out_channels,
127 kernel_h,
128 kernel_w,
129 stride_h,
130 stride_w,
131 weight,
132 bias,
133 )?;
134 self.layers.push(ModelLayer::Conv2d(layer));
135 self.frozen.push(false);
136 Ok(())
137 }
138
139 #[allow(clippy::too_many_arguments)]
140 pub fn add_conv2d_zero(
141 &mut self,
142 in_channels: usize,
143 out_channels: usize,
144 kernel_h: usize,
145 kernel_w: usize,
146 stride_h: usize,
147 stride_w: usize,
148 use_bias: bool,
149 ) -> Result<(), ModelError> {
150 let layer = Conv2dLayer::zero_init(
151 in_channels,
152 out_channels,
153 kernel_h,
154 kernel_w,
155 stride_h,
156 stride_w,
157 use_bias,
158 )?;
159 self.layers.push(ModelLayer::Conv2d(layer));
160 self.frozen.push(false);
161 Ok(())
162 }
163
164 #[allow(clippy::too_many_arguments)]
165 pub fn add_deformable_conv2d(
166 &mut self,
167 in_channels: usize,
168 out_channels: usize,
169 kernel_h: usize,
170 kernel_w: usize,
171 stride: usize,
172 padding: usize,
173 weight: Tensor,
174 offset_weight: Tensor,
175 bias: Option<Tensor>,
176 ) -> Result<(), ModelError> {
177 let layer = DeformableConv2dLayer::new(
178 in_channels,
179 out_channels,
180 kernel_h,
181 kernel_w,
182 stride,
183 padding,
184 weight,
185 offset_weight,
186 bias,
187 )?;
188 self.layers.push(ModelLayer::DeformableConv2d(layer));
189 self.frozen.push(false);
190 Ok(())
191 }
192
193 pub fn add_deformable_conv2d_zero(
194 &mut self,
195 in_channels: usize,
196 out_channels: usize,
197 kernel_h: usize,
198 kernel_w: usize,
199 stride: usize,
200 padding: usize,
201 use_bias: bool,
202 ) -> Result<(), ModelError> {
203 let layer = DeformableConv2dLayer::zero_init(
204 in_channels,
205 out_channels,
206 kernel_h,
207 kernel_w,
208 stride,
209 padding,
210 use_bias,
211 )?;
212 self.layers.push(ModelLayer::DeformableConv2d(layer));
213 self.frozen.push(false);
214 Ok(())
215 }
216
217 #[allow(clippy::too_many_arguments)]
218 pub fn add_depthwise_conv2d(
219 &mut self,
220 channels: usize,
221 kernel_h: usize,
222 kernel_w: usize,
223 stride_h: usize,
224 stride_w: usize,
225 weight: Tensor,
226 bias: Option<Tensor>,
227 ) -> Result<(), ModelError> {
228 let layer = DepthwiseConv2dLayer::new(
229 channels, kernel_h, kernel_w, stride_h, stride_w, weight, bias,
230 )?;
231 self.layers.push(ModelLayer::DepthwiseConv2d(layer));
232 self.frozen.push(false);
233 Ok(())
234 }
235
236 pub fn add_depthwise_conv2d_zero(
237 &mut self,
238 channels: usize,
239 kernel_h: usize,
240 kernel_w: usize,
241 stride_h: usize,
242 stride_w: usize,
243 use_bias: bool,
244 ) -> Result<(), ModelError> {
245 let layer = DepthwiseConv2dLayer::zero_init(
246 channels, kernel_h, kernel_w, stride_h, stride_w, use_bias,
247 )?;
248 self.layers.push(ModelLayer::DepthwiseConv2d(layer));
249 self.frozen.push(false);
250 Ok(())
251 }
252
253 #[allow(clippy::too_many_arguments)]
254 pub fn add_separable_conv2d(
255 &mut self,
256 in_channels: usize,
257 out_channels: usize,
258 kernel_h: usize,
259 kernel_w: usize,
260 stride_h: usize,
261 stride_w: usize,
262 depthwise_weight: Tensor,
263 pointwise_weight: Tensor,
264 bias: Option<Tensor>,
265 ) -> Result<(), ModelError> {
266 let layer = SeparableConv2dLayer::new(
267 in_channels,
268 out_channels,
269 kernel_h,
270 kernel_w,
271 stride_h,
272 stride_w,
273 depthwise_weight,
274 pointwise_weight,
275 bias,
276 )?;
277 self.layers.push(ModelLayer::SeparableConv2d(layer));
278 self.frozen.push(false);
279 Ok(())
280 }
281
282 #[allow(clippy::too_many_arguments)]
283 pub fn add_separable_conv2d_zero(
284 &mut self,
285 in_channels: usize,
286 out_channels: usize,
287 kernel_h: usize,
288 kernel_w: usize,
289 stride_h: usize,
290 stride_w: usize,
291 use_bias: bool,
292 ) -> Result<(), ModelError> {
293 let layer = SeparableConv2dLayer::zero_init(
294 in_channels,
295 out_channels,
296 kernel_h,
297 kernel_w,
298 stride_h,
299 stride_w,
300 use_bias,
301 )?;
302 self.layers.push(ModelLayer::SeparableConv2d(layer));
303 self.frozen.push(false);
304 Ok(())
305 }
306
307 pub fn add_batch_norm2d(
308 &mut self,
309 num_features: usize,
310 epsilon: f32,
311 gamma: Tensor,
312 beta: Tensor,
313 running_mean: Tensor,
314 running_var: Tensor,
315 ) -> Result<(), ModelError> {
316 let layer = BatchNorm2dLayer::new(
317 num_features,
318 epsilon,
319 gamma,
320 beta,
321 running_mean,
322 running_var,
323 )?;
324 self.layers.push(ModelLayer::BatchNorm2d(layer));
325 self.frozen.push(false);
326 Ok(())
327 }
328
329 pub fn add_batch_norm2d_identity(
330 &mut self,
331 num_features: usize,
332 epsilon: f32,
333 ) -> Result<(), ModelError> {
334 let layer = BatchNorm2dLayer::identity_init(num_features, epsilon)?;
335 self.layers.push(ModelLayer::BatchNorm2d(layer));
336 self.frozen.push(false);
337 Ok(())
338 }
339
340 pub fn add_max_pool2d(
341 &mut self,
342 kernel_h: usize,
343 kernel_w: usize,
344 stride_h: usize,
345 stride_w: usize,
346 ) -> Result<(), ModelError> {
347 let layer = MaxPool2dLayer::new(kernel_h, kernel_w, stride_h, stride_w)?;
348 self.layers.push(ModelLayer::MaxPool2d(layer));
349 self.frozen.push(false);
350 Ok(())
351 }
352
353 pub fn add_avg_pool2d(
354 &mut self,
355 kernel_h: usize,
356 kernel_w: usize,
357 stride_h: usize,
358 stride_w: usize,
359 ) -> Result<(), ModelError> {
360 let layer = AvgPool2dLayer::new(kernel_h, kernel_w, stride_h, stride_w)?;
361 self.layers.push(ModelLayer::AvgPool2d(layer));
362 self.frozen.push(false);
363 Ok(())
364 }
365
366 pub fn add_flatten(&mut self) {
367 self.layers.push(ModelLayer::Flatten(FlattenLayer::new()));
368 self.frozen.push(false);
369 }
370
371 pub fn add_global_avg_pool2d(&mut self) {
372 self.layers
373 .push(ModelLayer::GlobalAvgPool2d(GlobalAvgPool2dLayer::new()));
374 self.frozen.push(false);
375 }
376
377 pub fn add_softmax(&mut self) {
378 self.layers.push(ModelLayer::Softmax(SoftmaxLayer::new()));
379 self.frozen.push(false);
380 }
381
382 pub fn add_embedding(
383 &mut self,
384 graph: &mut Graph,
385 num_embeddings: usize,
386 embedding_dim: usize,
387 weight_init: Tensor,
388 ) -> Result<(), ModelError> {
389 let layer = EmbeddingLayer::new(graph, num_embeddings, embedding_dim, weight_init)?;
390 self.layers.push(ModelLayer::Embedding(layer));
391 self.frozen.push(false);
392 Ok(())
393 }
394
395 pub fn add_layer_norm(
396 &mut self,
397 graph: &mut Graph,
398 normalized_shape: usize,
399 eps: f32,
400 ) -> Result<(), ModelError> {
401 let layer = LayerNormLayer::new(graph, normalized_shape, eps)?;
402 self.layers.push(ModelLayer::LayerNorm(layer));
403 self.frozen.push(false);
404 Ok(())
405 }
406
407 pub fn add_group_norm(
408 &mut self,
409 graph: &mut Graph,
410 num_groups: usize,
411 num_channels: usize,
412 eps: f32,
413 ) -> Result<(), ModelError> {
414 let layer = GroupNormLayer::new(graph, num_groups, num_channels, eps)?;
415 self.layers.push(ModelLayer::GroupNorm(layer));
416 self.frozen.push(false);
417 Ok(())
418 }
419
    /// Swaps every `Linear` layer for a `LoraLinear` adapter built from the
    /// linear layer's existing weight/bias nodes, returning how many layers
    /// were converted.
    ///
    /// The persistent-node watermark is refreshed afterwards because LoRA
    /// construction adds new parameter nodes to `graph`.
    ///
    /// # Panics
    /// Panics if a `Linear` layer lacks a weight or bias node.
    /// NOTE(review): assumed impossible for layers built through
    /// `add_linear`/`add_linear_zero` — confirm against `LinearLayer`.
    pub fn apply_lora(
        &mut self,
        graph: &mut Graph,
        config: &LoraConfig,
    ) -> Result<usize, ModelError> {
        let mut count = 0;
        for layer in self.layers.iter_mut() {
            if let ModelLayer::Linear(linear) = layer {
                let in_features = linear.in_features();
                let out_features = linear.out_features();
                let weight_node = linear.weight_node().expect("linear layer has weight node");
                let bias_node = linear.bias_node().expect("linear layer has bias node");
                let lora = LoraLinear::from_linear(
                    graph,
                    weight_node,
                    bias_node,
                    in_features,
                    out_features,
                    config,
                )?;
                // Replace the layer in place; freeze flags are untouched.
                *layer = ModelLayer::LoraLinear(lora);
                count += 1;
            }
        }
        self.persistent_node_count = graph.node_count();
        Ok(count)
    }
450
    /// Folds every `LoraLinear` back into a plain `Linear` layer by merging
    /// the low-rank update into the base weight, returning how many layers
    /// were merged.
    ///
    /// A LoRA layer without a bias node is rebuilt with an all-zero bias
    /// tensor, which is functionally equivalent to having no bias.
    pub fn merge_lora(&mut self, graph: &mut Graph) -> Result<usize, ModelError> {
        let mut count = 0;
        for layer in self.layers.iter_mut() {
            if let ModelLayer::LoraLinear(lora) = layer {
                let merged_weight = lora.merge(graph)?;
                let in_features = lora.in_features;
                let out_features = lora.out_features;
                let bias_tensor = if let Some(bias_node) = lora.bias {
                    graph.value(bias_node)?.clone()
                } else {
                    // No bias on the LoRA layer: substitute zeros.
                    Tensor::zeros(vec![out_features])?
                };
                let linear =
                    LinearLayer::new(graph, in_features, out_features, merged_weight, bias_tensor)?;
                *layer = ModelLayer::Linear(linear);
                count += 1;
            }
        }
        // Merging created fresh parameter nodes, so refresh the watermark.
        self.persistent_node_count = graph.node_count();
        Ok(count)
    }
475
476 pub fn optimize(&mut self, graph: &mut Graph) -> usize {
481 let before = self.layers.len();
482 let optimized = optimize_sequential(self, graph);
483 let after = optimized.layers.len();
484 *self = optimized;
485 before - after
486 }
487
488 pub fn set_training(&mut self, training: bool) {
490 self.training = training;
491 for layer in &mut self.layers {
492 if let ModelLayer::Dropout(d) = layer {
493 d.set_training(training);
494 }
495 }
496 }
497
    /// Switches the model to evaluation mode; shorthand for
    /// `set_training(false)`.
    pub fn eval(&mut self) {
        self.set_training(false);
    }
502
    /// Switches the model to training mode; shorthand for
    /// `set_training(true)`.
    pub fn train_mode(&mut self) {
        self.set_training(true);
    }
507
    /// Returns `true` while the model is in training mode.
    pub fn is_training(&self) -> bool {
        self.training
    }
512
513 pub fn summary(&self) -> String {
518 use std::fmt::Write;
519
520 let mut out = String::new();
521 let sep = "-".repeat(65);
522 writeln!(out, "{sep}").expect("write to String");
523 writeln!(
524 out,
525 " {:>3} {:<25} {:>20} {:>8}",
526 "#", "Layer", "Details", "Params"
527 )
528 .expect("write to String");
529 writeln!(out, "{sep}").expect("write to String");
530
531 let mut total_params = 0usize;
532
533 for (i, layer) in self.layers.iter().enumerate() {
534 let (name, details, params) = match layer {
535 ModelLayer::Linear(l) => {
536 let p = l.in_features() * l.out_features() + l.out_features();
537 (
538 "Linear".to_string(),
539 format!("{}→{}", l.in_features(), l.out_features()),
540 p,
541 )
542 }
543 ModelLayer::Conv2d(l) => {
544 let p = l.out_channels() * l.in_channels() * l.kernel_h() * l.kernel_w()
545 + if l.bias().is_some() {
546 l.out_channels()
547 } else {
548 0
549 };
550 (
551 "Conv2d".to_string(),
552 format!(
553 "{}→{} {}x{} s{}",
554 l.in_channels(),
555 l.out_channels(),
556 l.kernel_h(),
557 l.kernel_w(),
558 l.stride_h(),
559 ),
560 p,
561 )
562 }
563 ModelLayer::BatchNorm2d(l) => {
564 let p = l.num_features() * 2; (
566 "BatchNorm2d".to_string(),
567 format!("{}", l.num_features()),
568 p,
569 )
570 }
571 ModelLayer::MaxPool2d(l) => (
572 "MaxPool2d".to_string(),
573 format!("{}x{} s{}", l.kernel_h(), l.kernel_w(), l.stride_h()),
574 0,
575 ),
576 ModelLayer::AvgPool2d(l) => (
577 "AvgPool2d".to_string(),
578 format!("{}x{} s{}", l.kernel_h(), l.kernel_w(), l.stride_h()),
579 0,
580 ),
581 ModelLayer::GlobalAvgPool2d(_) => ("GlobalAvgPool2d".to_string(), String::new(), 0),
582 ModelLayer::Flatten(_) => ("Flatten".to_string(), String::new(), 0),
583 ModelLayer::ReLU(_) => ("ReLU".to_string(), String::new(), 0),
584 ModelLayer::LeakyReLU(_) => ("LeakyReLU".to_string(), String::new(), 0),
585 ModelLayer::Sigmoid(_) => ("Sigmoid".to_string(), String::new(), 0),
586 ModelLayer::Tanh(_) => ("Tanh".to_string(), String::new(), 0),
587 ModelLayer::Softmax(_) => ("Softmax".to_string(), String::new(), 0),
588 ModelLayer::Dropout(d) => ("Dropout".to_string(), format!("p={:.2}", d.rate()), 0),
589 ModelLayer::Embedding(e) => {
590 let p = e.num_embeddings() * e.embedding_dim();
591 (
592 "Embedding".to_string(),
593 format!("{}x{}", e.num_embeddings(), e.embedding_dim()),
594 p,
595 )
596 }
597 ModelLayer::LayerNorm(_) => ("LayerNorm".to_string(), String::new(), 0),
598 ModelLayer::GroupNorm(_) => ("GroupNorm".to_string(), String::new(), 0),
599 ModelLayer::DepthwiseConv2d(_) => ("DepthwiseConv2d".to_string(), String::new(), 0),
600 ModelLayer::SeparableConv2d(_) => ("SeparableConv2d".to_string(), String::new(), 0),
601 ModelLayer::DeformableConv2d(l) => {
602 let p = l.weight.data().len() + l.offset_weight.data().len();
603 (
604 "DeformableConv2d".to_string(),
605 format!(
606 "{}→{} {}x{}",
607 l.in_channels(),
608 l.out_channels(),
609 l.kernel_h,
610 l.kernel_w
611 ),
612 p,
613 )
614 }
615 ModelLayer::LoraLinear(l) => (
616 "LoraLinear".to_string(),
617 format!("{}→{} r={}", l.in_features, l.out_features, l.rank),
618 l.in_features * l.rank + l.rank * l.out_features,
619 ),
620 ModelLayer::Conv1d(l) => {
621 let p = l.kernel().data().len();
622 ("Conv1d".to_string(), format!("k={}", l.kernel_size()), p)
623 }
624 ModelLayer::Conv3d(l) => {
625 let p = l.weight().data().len();
626 (
627 "Conv3d".to_string(),
628 format!("{}→{}", l.in_channels(), l.out_channels()),
629 p,
630 )
631 }
632 ModelLayer::ConvTranspose2d(l) => {
633 let p = l.kernel().data().len();
634 (
635 "ConvTranspose2d".to_string(),
636 format!("s={}", l.stride()),
637 p,
638 )
639 }
640 ModelLayer::AdaptiveAvgPool2d(l) => (
641 "AdaptiveAvgPool2d".to_string(),
642 format!("{}x{}", l.output_h(), l.output_w()),
643 0,
644 ),
645 ModelLayer::AdaptiveMaxPool2d(l) => (
646 "AdaptiveMaxPool2d".to_string(),
647 format!("{}x{}", l.output_h(), l.output_w()),
648 0,
649 ),
650 ModelLayer::InstanceNorm(_) => ("InstanceNorm".to_string(), String::new(), 0),
651 ModelLayer::PixelShuffle(l) => (
652 "PixelShuffle".to_string(),
653 format!("r={}", l.upscale_factor()),
654 0,
655 ),
656 ModelLayer::Upsample(l) => {
657 ("Upsample".to_string(), format!("{}x", l.scale_factor()), 0)
658 }
659 ModelLayer::GELU(_) => ("GELU".to_string(), String::new(), 0),
660 ModelLayer::SiLU(_) => ("SiLU".to_string(), String::new(), 0),
661 ModelLayer::Mish(_) => ("Mish".to_string(), String::new(), 0),
662 ModelLayer::PReLU(l) => (
663 "PReLU".to_string(),
664 format!("channels={}", l.alpha().len()),
665 l.alpha().len(),
666 ),
667 ModelLayer::ResidualBlock(r) => (
668 "ResidualBlock".to_string(),
669 format!("{} layers", r.layers().len()),
670 0,
671 ),
672 ModelLayer::Rnn(l) => {
673 let p = l.input_size * l.hidden_size
674 + l.hidden_size * l.hidden_size
675 + l.hidden_size;
676 (
677 "Rnn".to_string(),
678 format!("{}→{}", l.input_size, l.hidden_size),
679 p,
680 )
681 }
682 ModelLayer::Lstm(l) => {
683 let h4 = 4 * l.hidden_size;
684 let p = l.input_size * h4 + l.hidden_size * h4 + h4;
685 (
686 "Lstm".to_string(),
687 format!("{}→{}", l.input_size, l.hidden_size),
688 p,
689 )
690 }
691 ModelLayer::Gru(l) => {
692 let h3 = 3 * l.hidden_size;
693 let p = l.input_size * h3 + l.hidden_size * h3 + 2 * h3;
694 (
695 "Gru".to_string(),
696 format!("{}→{}", l.input_size, l.hidden_size),
697 p,
698 )
699 }
700 ModelLayer::MultiHeadAttention(_) => {
701 ("MultiHeadAttention".to_string(), String::new(), 0)
702 }
703 ModelLayer::TransformerEncoder(_) => {
704 ("TransformerEncoder".to_string(), String::new(), 0)
705 }
706 ModelLayer::FeedForward(_) => ("FeedForward".to_string(), String::new(), 0),
707 };
708
709 total_params += params;
710 let params_str = if params > 0 {
711 format_param_count(params)
712 } else {
713 "-".to_string()
714 };
715 writeln!(
716 out,
717 " {:>3} {:<25} {:>20} {:>8}",
718 i, name, details, params_str
719 )
720 .expect("write to String");
721 }
722
723 writeln!(out, "{sep}").expect("write to String");
724 writeln!(
725 out,
726 " Total: {} layers, {} parameters",
727 self.layers.len(),
728 format_param_count(total_params)
729 )
730 .expect("write to String");
731 writeln!(out, "{sep}").expect("write to String");
732 out
733 }
734
    /// Total number of trainable parameters across all layers.
    ///
    /// Counts are derived either from layer dimensions (Linear, Conv2d, …)
    /// or directly from stored tensor lengths (Conv1d/3d, depthwise, …).
    /// NOTE(review): LayerNorm/GroupNorm gamma+beta, ResidualBlock inner
    /// layers, and the recurrent/attention layers' graph-held parameters are
    /// counted as 0 here — confirm this is intentional.
    pub fn num_parameters(&self) -> usize {
        let mut total = 0usize;
        for layer in &self.layers {
            let params = match layer {
                ModelLayer::Linear(l) => l.in_features() * l.out_features() + l.out_features(),
                ModelLayer::Conv2d(l) => {
                    l.out_channels() * l.in_channels() * l.kernel_h() * l.kernel_w()
                        + if l.bias().is_some() {
                            l.out_channels()
                        } else {
                            0
                        }
                }
                // gamma + beta per feature; running stats are buffers.
                ModelLayer::BatchNorm2d(l) => l.num_features() * 2,
                ModelLayer::Embedding(e) => e.num_embeddings() * e.embedding_dim(),
                // Only the low-rank A/B matrices; the frozen base weight is excluded.
                ModelLayer::LoraLinear(l) => l.in_features * l.rank + l.rank * l.out_features,
                ModelLayer::Conv1d(l) => l.kernel().data().len(),
                ModelLayer::Conv3d(l) => l.weight().data().len(),
                ModelLayer::ConvTranspose2d(l) => l.kernel().data().len(),
                ModelLayer::DepthwiseConv2d(l) => {
                    l.weight().data().len() + l.bias().map_or(0, |b| b.data().len())
                }
                ModelLayer::SeparableConv2d(l) => {
                    l.depthwise().weight().data().len()
                        + l.depthwise().bias().map_or(0, |b| b.data().len())
                        + l.pointwise().weight().data().len()
                        + l.pointwise().bias().map_or(0, |b| b.data().len())
                }
                ModelLayer::DeformableConv2d(l) => {
                    l.weight.data().len()
                        + l.offset_weight.data().len()
                        + l.bias.as_ref().map_or(0, |b| b.data().len())
                }
                ModelLayer::InstanceNorm(_) => 0,
                // Parameter-free layers (or layers whose parameters live on
                // the graph and are not counted here).
                ModelLayer::LayerNorm(_)
                | ModelLayer::GroupNorm(_)
                | ModelLayer::MaxPool2d(_)
                | ModelLayer::AvgPool2d(_)
                | ModelLayer::GlobalAvgPool2d(_)
                | ModelLayer::Flatten(_)
                | ModelLayer::ReLU(_)
                | ModelLayer::LeakyReLU(_)
                | ModelLayer::Sigmoid(_)
                | ModelLayer::Tanh(_)
                | ModelLayer::Softmax(_)
                | ModelLayer::Dropout(_)
                | ModelLayer::AdaptiveAvgPool2d(_)
                | ModelLayer::AdaptiveMaxPool2d(_)
                | ModelLayer::PixelShuffle(_)
                | ModelLayer::Upsample(_)
                | ModelLayer::GELU(_)
                | ModelLayer::SiLU(_)
                | ModelLayer::Mish(_) => 0,
                ModelLayer::PReLU(l) => l.alpha().len(),
                ModelLayer::ResidualBlock(_) => 0,
                ModelLayer::Rnn(l) => {
                    l.input_size * l.hidden_size + l.hidden_size * l.hidden_size + l.hidden_size
                }
                ModelLayer::Lstm(l) => {
                    // Four gates share the 4*hidden layout.
                    let h4 = 4 * l.hidden_size;
                    l.input_size * h4 + l.hidden_size * h4 + h4
                }
                ModelLayer::Gru(l) => {
                    // Three gates share the 3*hidden layout.
                    let h3 = 3 * l.hidden_size;
                    l.input_size * h3 + l.hidden_size * h3 + 2 * h3
                }
                ModelLayer::MultiHeadAttention(_)
                | ModelLayer::TransformerEncoder(_)
                | ModelLayer::FeedForward(_) => 0,
            };
            total += params;
        }
        total
    }
810
811 pub fn freeze_layer(&mut self, index: usize) -> Result<(), ModelError> {
813 if index >= self.layers.len() {
814 return Err(ModelError::InvalidLayerIndex {
815 index,
816 count: self.layers.len(),
817 });
818 }
819 self.frozen[index] = true;
820 Ok(())
821 }
822
823 pub fn unfreeze_layer(&mut self, index: usize) -> Result<(), ModelError> {
825 if index >= self.layers.len() {
826 return Err(ModelError::InvalidLayerIndex {
827 index,
828 count: self.layers.len(),
829 });
830 }
831 self.frozen[index] = false;
832 Ok(())
833 }
834
    /// Per-layer freeze flags, index-aligned with `layers()`.
    pub fn frozen_mask(&self) -> &[bool] {
        &self.frozen
    }
839
    /// Number of trainable parameters, excluding frozen layers.
    ///
    /// Per-layer counting mirrors `num_parameters`; the only difference is
    /// that layers whose `frozen` flag is set are skipped.
    pub fn trainable_parameters(&self) -> usize {
        let mut total = 0usize;
        for (i, layer) in self.layers.iter().enumerate() {
            // Frozen layers contribute nothing.
            if self.frozen[i] {
                continue;
            }
            let params = match layer {
                ModelLayer::Linear(l) => l.in_features() * l.out_features() + l.out_features(),
                ModelLayer::Conv2d(l) => {
                    l.out_channels() * l.in_channels() * l.kernel_h() * l.kernel_w()
                        + if l.bias().is_some() {
                            l.out_channels()
                        } else {
                            0
                        }
                }
                // gamma + beta per feature.
                ModelLayer::BatchNorm2d(l) => l.num_features() * 2,
                ModelLayer::Embedding(e) => e.num_embeddings() * e.embedding_dim(),
                // Only the low-rank A/B matrices; the frozen base weight is excluded.
                ModelLayer::LoraLinear(l) => l.in_features * l.rank + l.rank * l.out_features,
                ModelLayer::Conv1d(l) => l.kernel().data().len(),
                ModelLayer::Conv3d(l) => l.weight().data().len(),
                ModelLayer::ConvTranspose2d(l) => l.kernel().data().len(),
                ModelLayer::DepthwiseConv2d(l) => {
                    l.weight().data().len() + l.bias().map_or(0, |b| b.data().len())
                }
                ModelLayer::SeparableConv2d(l) => {
                    l.depthwise().weight().data().len()
                        + l.depthwise().bias().map_or(0, |b| b.data().len())
                        + l.pointwise().weight().data().len()
                        + l.pointwise().bias().map_or(0, |b| b.data().len())
                }
                ModelLayer::DeformableConv2d(l) => {
                    l.weight.data().len()
                        + l.offset_weight.data().len()
                        + l.bias.as_ref().map_or(0, |b| b.data().len())
                }
                ModelLayer::InstanceNorm(_) => 0,
                // Parameter-free layers (or layers whose parameters live on
                // the graph and are not counted here).
                ModelLayer::LayerNorm(_)
                | ModelLayer::GroupNorm(_)
                | ModelLayer::MaxPool2d(_)
                | ModelLayer::AvgPool2d(_)
                | ModelLayer::GlobalAvgPool2d(_)
                | ModelLayer::Flatten(_)
                | ModelLayer::ReLU(_)
                | ModelLayer::LeakyReLU(_)
                | ModelLayer::Sigmoid(_)
                | ModelLayer::Tanh(_)
                | ModelLayer::Softmax(_)
                | ModelLayer::Dropout(_)
                | ModelLayer::AdaptiveAvgPool2d(_)
                | ModelLayer::AdaptiveMaxPool2d(_)
                | ModelLayer::PixelShuffle(_)
                | ModelLayer::Upsample(_)
                | ModelLayer::GELU(_)
                | ModelLayer::SiLU(_)
                | ModelLayer::Mish(_) => 0,
                ModelLayer::PReLU(l) => l.alpha().len(),
                ModelLayer::ResidualBlock(_) => 0,
                ModelLayer::Rnn(l) => {
                    l.input_size * l.hidden_size + l.hidden_size * l.hidden_size + l.hidden_size
                }
                ModelLayer::Lstm(l) => {
                    let h4 = 4 * l.hidden_size;
                    l.input_size * h4 + l.hidden_size * h4 + h4
                }
                ModelLayer::Gru(l) => {
                    let h3 = 3 * l.hidden_size;
                    l.input_size * h3 + l.hidden_size * h3 + 2 * h3
                }
                ModelLayer::MultiHeadAttention(_)
                | ModelLayer::TransformerEncoder(_)
                | ModelLayer::FeedForward(_) => 0,
            };
            total += params;
        }
        total
    }
918
    /// Returns `(name, tensor)` pairs for every parameter in parameter-bearing
    /// layers, in layer order.
    ///
    /// Names follow the pattern `<type><index>_<param>` (e.g. `linear0_weight`)
    /// where `<index>` counts layers of the same type, not the overall layer
    /// position. Graph-backed parameters are resolved via `graph.value(...)`.
    ///
    /// NOTE(review): RNN/LSTM/GRU, attention, transformer, feed-forward,
    /// residual-block and deformable-conv parameters are currently omitted —
    /// confirm this is intentional before relying on this for checkpointing.
    ///
    /// # Panics
    /// Panics if a `Linear` layer lacks its weight or bias node (expected to
    /// always be present for layers built through `add_linear*`).
    pub fn named_parameters<'a>(
        &'a self,
        graph: &'a Graph,
    ) -> Result<Vec<(String, &'a Tensor)>, ModelError> {
        let mut result = Vec::new();
        // Per-type running index used to build unique parameter names.
        let mut type_counts = std::collections::HashMap::<&str, usize>::new();

        for layer in &self.layers {
            match layer {
                ModelLayer::Linear(l) => {
                    let idx = type_counts.entry("linear").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((
                        format!("linear{i}_weight"),
                        graph.value(l.weight_node().expect("linear layer has weight node"))?,
                    ));
                    result.push((
                        format!("linear{i}_bias"),
                        graph.value(l.bias_node().expect("linear layer has bias node"))?,
                    ));
                }
                ModelLayer::Conv2d(l) => {
                    let idx = type_counts.entry("conv2d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("conv2d{i}_weight"), l.weight()));
                    if let Some(b) = l.bias() {
                        result.push((format!("conv2d{i}_bias"), b));
                    }
                }
                ModelLayer::BatchNorm2d(l) => {
                    let idx = type_counts.entry("batchnorm2d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    // Running mean/var are buffers and intentionally not exported.
                    result.push((format!("batchnorm2d{i}_gamma"), l.gamma()));
                    result.push((format!("batchnorm2d{i}_beta"), l.beta()));
                }
                ModelLayer::Embedding(e) => {
                    let idx = type_counts.entry("embedding").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((
                        format!("embedding{i}_weight"),
                        graph.value(e.weight_node())?,
                    ));
                }
                ModelLayer::LayerNorm(l) => {
                    let idx = type_counts.entry("layernorm").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("layernorm{i}_gamma"), graph.value(l.gamma_node())?));
                    result.push((format!("layernorm{i}_beta"), graph.value(l.beta_node())?));
                }
                ModelLayer::GroupNorm(l) => {
                    let idx = type_counts.entry("groupnorm").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("groupnorm{i}_gamma"), graph.value(l.gamma_node())?));
                    result.push((format!("groupnorm{i}_beta"), graph.value(l.beta_node())?));
                }
                ModelLayer::LoraLinear(l) => {
                    let idx = type_counts.entry("loralinear").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((
                        format!("loralinear{i}_frozen_weight"),
                        graph.value(l.frozen_weight)?,
                    ));
                    result.push((format!("loralinear{i}_lora_a"), graph.value(l.lora_a)?));
                    result.push((format!("loralinear{i}_lora_b"), graph.value(l.lora_b)?));
                    if let Some(bias_node) = l.bias {
                        result.push((format!("loralinear{i}_bias"), graph.value(bias_node)?));
                    }
                }
                ModelLayer::DepthwiseConv2d(l) => {
                    let idx = type_counts.entry("depthwiseconv2d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("depthwiseconv2d{i}_weight"), l.weight()));
                    if let Some(b) = l.bias() {
                        result.push((format!("depthwiseconv2d{i}_bias"), b));
                    }
                }
                ModelLayer::SeparableConv2d(l) => {
                    let idx = type_counts.entry("separableconv2d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((
                        format!("separableconv2d{i}_depthwise_weight"),
                        l.depthwise().weight(),
                    ));
                    if let Some(b) = l.depthwise().bias() {
                        result.push((format!("separableconv2d{i}_depthwise_bias"), b));
                    }
                    result.push((
                        format!("separableconv2d{i}_pointwise_weight"),
                        l.pointwise().weight(),
                    ));
                    if let Some(b) = l.pointwise().bias() {
                        result.push((format!("separableconv2d{i}_pointwise_bias"), b));
                    }
                }
                ModelLayer::Conv1d(l) => {
                    let idx = type_counts.entry("conv1d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("conv1d{i}_weight"), l.kernel()));
                }
                ModelLayer::Conv3d(l) => {
                    let idx = type_counts.entry("conv3d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("conv3d{i}_weight"), l.weight()));
                }
                ModelLayer::ConvTranspose2d(l) => {
                    let idx = type_counts.entry("convtranspose2d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("convtranspose2d{i}_weight"), l.kernel()));
                }
                // Layers with no exported parameters (see NOTE above).
                ModelLayer::ReLU(_)
                | ModelLayer::LeakyReLU(_)
                | ModelLayer::Sigmoid(_)
                | ModelLayer::Tanh(_)
                | ModelLayer::Softmax(_)
                | ModelLayer::Dropout(_)
                | ModelLayer::MaxPool2d(_)
                | ModelLayer::AvgPool2d(_)
                | ModelLayer::GlobalAvgPool2d(_)
                | ModelLayer::Flatten(_)
                | ModelLayer::AdaptiveAvgPool2d(_)
                | ModelLayer::AdaptiveMaxPool2d(_)
                | ModelLayer::InstanceNorm(_)
                | ModelLayer::PixelShuffle(_)
                | ModelLayer::Upsample(_)
                | ModelLayer::GELU(_)
                | ModelLayer::SiLU(_)
                | ModelLayer::Mish(_)
                | ModelLayer::PReLU(_)
                | ModelLayer::ResidualBlock(_)
                | ModelLayer::Rnn(_)
                | ModelLayer::Lstm(_)
                | ModelLayer::Gru(_)
                | ModelLayer::MultiHeadAttention(_)
                | ModelLayer::TransformerEncoder(_)
                | ModelLayer::FeedForward(_)
                | ModelLayer::DeformableConv2d(_) => {}
            }
        }
        Ok(result)
    }
1081
    /// The model's layers in execution order.
    pub fn layers(&self) -> &[ModelLayer] {
        &self.layers
    }
1085
    /// Mutable access to the model's layers in execution order.
    pub fn layers_mut(&mut self) -> &mut [ModelLayer] {
        &mut self.layers
    }
1089
1090 pub fn add_residual_block(&mut self, layers: Vec<ModelLayer>) {
1092 self.layers
1093 .push(ModelLayer::ResidualBlock(ResidualBlock::new(layers)));
1094 self.frozen.push(false);
1095 }
1096
1097 pub fn add_rnn(&mut self, input_size: usize, hidden_size: usize, seed: u64) {
1098 self.layers.push(ModelLayer::Rnn(RnnLayer::new(
1099 input_size,
1100 hidden_size,
1101 seed,
1102 )));
1103 self.frozen.push(false);
1104 }
1105
1106 pub fn add_lstm(&mut self, input_size: usize, hidden_size: usize, seed: u64) {
1107 self.layers.push(ModelLayer::Lstm(LstmLayer::new(
1108 input_size,
1109 hidden_size,
1110 seed,
1111 )));
1112 self.frozen.push(false);
1113 }
1114
1115 pub fn add_gru(&mut self, input_size: usize, hidden_size: usize, seed: u64) {
1116 self.layers.push(ModelLayer::Gru(GruLayer::new(
1117 input_size,
1118 hidden_size,
1119 seed,
1120 )));
1121 self.frozen.push(false);
1122 }
1123
1124 pub fn add_multi_head_attention(&mut self, d_model: usize, num_heads: usize, seed: u64) {
1125 self.layers.push(ModelLayer::MultiHeadAttention(
1126 MultiHeadAttentionLayer::new(d_model, num_heads, seed),
1127 ));
1128 self.frozen.push(false);
1129 }
1130
1131 pub fn add_transformer_encoder(
1132 &mut self,
1133 d_model: usize,
1134 num_heads: usize,
1135 d_ff: usize,
1136 seed: u64,
1137 ) {
1138 self.layers.push(ModelLayer::TransformerEncoder(
1139 TransformerEncoderLayer::new(d_model, num_heads, d_ff, seed),
1140 ));
1141 self.frozen.push(false);
1142 }
1143
1144 pub fn add_feed_forward(&mut self, d_model: usize, d_ff: usize, seed: u64) {
1145 self.layers
1146 .push(ModelLayer::FeedForward(FeedForwardLayer::new(
1147 d_model, d_ff, seed,
1148 )));
1149 self.frozen.push(false);
1150 }
1151
1152 pub fn push_raw_layer(&mut self, layer: ModelLayer) {
1154 self.layers.push(layer);
1155 self.frozen.push(false);
1156 }
1157
1158 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
1159 let mut current = input;
1160 for layer in &self.layers {
1161 current = layer.forward(graph, current)?;
1162 }
1163 Ok(current)
1164 }
1165
    /// Runs a graph-free forward pass purely on tensors.
    ///
    /// Simple activations (ReLU, Sigmoid, Tanh, LeakyReLU) are computed
    /// inline, elementwise, preserving shape; every other supported layer
    /// delegates to its own `forward_inference`.
    ///
    /// # Errors
    /// Returns [`ModelError::GraphOnlyLayer`] for Embedding / LayerNorm /
    /// GroupNorm / LoraLinear, whose parameters exist only as graph nodes,
    /// and propagates any layer or tensor-construction error.
    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let mut current = input.clone();
        for layer in &self.layers {
            current = match layer {
                ModelLayer::Conv2d(l) => l.forward_inference(&current)?,
                ModelLayer::BatchNorm2d(l) => l.forward_inference(&current)?,
                ModelLayer::MaxPool2d(l) => l.forward_inference(&current)?,
                ModelLayer::AvgPool2d(l) => l.forward_inference(&current)?,
                ModelLayer::Flatten(l) => l.forward_inference(&current)?,
                ModelLayer::Softmax(l) => l.forward_inference(&current)?,
                ModelLayer::ReLU(_) => {
                    // Elementwise max(v, 0).
                    let out_data: Vec<f32> = current.data().iter().map(|&v| v.max(0.0)).collect();
                    Tensor::from_vec(current.shape().to_vec(), out_data)?
                }
                ModelLayer::Sigmoid(_) => {
                    // Elementwise logistic: 1 / (1 + e^-v).
                    let out_data: Vec<f32> = current
                        .data()
                        .iter()
                        .map(|&v| 1.0 / (1.0 + (-v).exp()))
                        .collect();
                    Tensor::from_vec(current.shape().to_vec(), out_data)?
                }
                ModelLayer::Tanh(_) => {
                    let out_data: Vec<f32> = current.data().iter().map(|&v| v.tanh()).collect();
                    Tensor::from_vec(current.shape().to_vec(), out_data)?
                }
                ModelLayer::LeakyReLU(l) => {
                    // v for v >= 0, slope * v otherwise.
                    let slope = l.negative_slope();
                    let out_data: Vec<f32> = current
                        .data()
                        .iter()
                        .map(|&v| if v >= 0.0 { v } else { slope * v })
                        .collect();
                    Tensor::from_vec(current.shape().to_vec(), out_data)?
                }
                ModelLayer::GlobalAvgPool2d(l) => l.forward_inference(&current)?,
                ModelLayer::DepthwiseConv2d(l) => l.forward_inference(&current)?,
                ModelLayer::SeparableConv2d(l) => l.forward_inference(&current)?,
                ModelLayer::Dropout(d) => {
                    // NOTE(review): in training mode this only rescales every
                    // element by 1/(1 - rate) and never zeroes any, so
                    // activations are uniformly inflated rather than dropped;
                    // in eval mode (or rate == 0) it is an identity. Confirm
                    // the training branch is intentional.
                    if self.training && d.rate() > 0.0 {
                        let scale = 1.0 / (1.0 - d.rate());
                        let scaled: Vec<f32> = current.data().iter().map(|&v| v * scale).collect();
                        Tensor::from_vec(current.shape().to_vec(), scaled)?
                    } else {
                        current
                    }
                }
                ModelLayer::Conv1d(l) => l.forward_inference(&current)?,
                ModelLayer::Conv3d(l) => l.forward_inference(&current)?,
                ModelLayer::ConvTranspose2d(l) => l.forward_inference(&current)?,
                ModelLayer::AdaptiveAvgPool2d(l) => l.forward_inference(&current)?,
                ModelLayer::AdaptiveMaxPool2d(l) => l.forward_inference(&current)?,
                ModelLayer::InstanceNorm(l) => l.forward_inference(&current)?,
                ModelLayer::PixelShuffle(l) => l.forward_inference(&current)?,
                ModelLayer::Upsample(l) => l.forward_inference(&current)?,
                ModelLayer::GELU(l) => l.forward_inference(&current)?,
                ModelLayer::SiLU(l) => l.forward_inference(&current)?,
                ModelLayer::Mish(l) => l.forward_inference(&current)?,
                ModelLayer::PReLU(l) => l.forward_inference(&current)?,
                ModelLayer::ResidualBlock(l) => l.forward_inference(&current)?,
                ModelLayer::Rnn(l) => l.forward_inference(&current)?,
                ModelLayer::Lstm(l) => l.forward_inference(&current)?,
                ModelLayer::Gru(l) => l.forward_inference(&current)?,
                ModelLayer::MultiHeadAttention(l) => l.forward_inference(&current)?,
                ModelLayer::TransformerEncoder(l) => l.forward_inference(&current)?,
                ModelLayer::FeedForward(l) => l.forward_inference(&current)?,
                ModelLayer::DeformableConv2d(l) => l.forward_inference(&current)?,
                ModelLayer::Linear(l) => l.forward_inference(&current)?,
                // These layers keep their parameters only in the autograd
                // graph, so they cannot run without it.
                ModelLayer::Embedding(_)
                | ModelLayer::LayerNorm(_)
                | ModelLayer::GroupNorm(_)
                | ModelLayer::LoraLinear(_) => return Err(ModelError::GraphOnlyLayer),
            };
        }
        Ok(current)
    }
1246
    /// Registers graph nodes for every parameterized layer that does not
    /// have them yet, then records the resulting graph node count in
    /// `persistent_node_count`.
    ///
    /// Each arm's guard probes one representative parameter node as an
    /// "already registered" sentinel, which makes the call idempotent:
    /// layers whose nodes exist are skipped.
    pub fn register_cnn_params(&mut self, graph: &mut Graph) {
        for layer in &mut self.layers {
            match layer {
                ModelLayer::Conv2d(l) if l.weight_node().is_none() => l.register_params(graph),
                ModelLayer::BatchNorm2d(l) if l.gamma_node().is_none() => l.register_params(graph),
                ModelLayer::DepthwiseConv2d(l) if l.weight_node().is_none() => {
                    l.register_params(graph)
                }
                // Separable conv registers both sub-convs; the depthwise
                // weight stands in for the whole layer's registered state.
                ModelLayer::SeparableConv2d(l) if l.depthwise().weight_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::Conv1d(l) if l.weight_node().is_none() => l.register_params(graph),
                ModelLayer::Conv3d(l) if l.weight_node().is_none() => l.register_params(graph),
                ModelLayer::MultiHeadAttention(l) if l.w_q_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::FeedForward(l) if l.w1_node().is_none() => l.register_params(graph),
                ModelLayer::TransformerEncoder(l) if l.ln1_gamma_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::Rnn(l) if l.w_ih_node().is_none() => l.register_params(graph),
                ModelLayer::Lstm(l) if l.w_ih_node().is_none() => l.register_params(graph),
                ModelLayer::Gru(l) if l.w_ih_node().is_none() => l.register_params(graph),
                ModelLayer::DeformableConv2d(l) if l.weight_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::ConvTranspose2d(l) if l.weight_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::InstanceNorm(l) if l.gamma_node().is_none() => l.register_params(graph),
                ModelLayer::PReLU(l) if l.alpha_node().is_none() => l.register_params(graph),
                _ => {}
            }
        }
        // Snapshot the node count so the freshly registered parameter nodes
        // are counted as persistent — presumably so they survive when the
        // graph is trimmed back to this count; TODO confirm against callers.
        self.persistent_node_count = graph.node_count();
    }
1288
1289 pub fn sync_cnn_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
1291 for layer in &mut self.layers {
1292 match layer {
1293 ModelLayer::Conv2d(l) => l.sync_from_graph(graph)?,
1294 ModelLayer::BatchNorm2d(l) => l.sync_from_graph(graph)?,
1295 ModelLayer::DepthwiseConv2d(l) => l.sync_from_graph(graph)?,
1296 ModelLayer::SeparableConv2d(l) => l.sync_from_graph(graph)?,
1297 _ => {}
1298 }
1299 }
1300 Ok(())
1301 }
1302
1303 pub fn trainable_nodes(&self) -> Vec<NodeId> {
1304 let mut out = Vec::new();
1305 for layer in &self.layers {
1306 match layer {
1307 ModelLayer::Linear(linear) => {
1308 out.extend(linear.trainable_nodes());
1309 }
1310 ModelLayer::Conv2d(conv) => {
1311 if let Some(w) = conv.weight_node() {
1312 out.push(w);
1313 }
1314 if let Some(b) = conv.bias_node() {
1315 out.push(b);
1316 }
1317 }
1318 ModelLayer::BatchNorm2d(bn) => {
1319 if let Some(g) = bn.gamma_node() {
1320 out.push(g);
1321 }
1322 if let Some(b) = bn.beta_node() {
1323 out.push(b);
1324 }
1325 }
1326 ModelLayer::DepthwiseConv2d(dw) => {
1327 if let Some(w) = dw.weight_node() {
1328 out.push(w);
1329 }
1330 if let Some(b) = dw.bias_node() {
1331 out.push(b);
1332 }
1333 }
1334 ModelLayer::SeparableConv2d(sep) => {
1335 if let Some(w) = sep.depthwise().weight_node() {
1336 out.push(w);
1337 }
1338 if let Some(b) = sep.depthwise().bias_node() {
1339 out.push(b);
1340 }
1341 if let Some(w) = sep.pointwise().weight_node() {
1342 out.push(w);
1343 }
1344 if let Some(b) = sep.pointwise().bias_node() {
1345 out.push(b);
1346 }
1347 }
1348 ModelLayer::LoraLinear(lora) => {
1349 out.extend(lora.trainable_params());
1350 }
1351 _ => {}
1352 }
1353 }
1354 out
1355 }
1356
    /// Graph node count recorded after the most recent parameter
    /// registration (updated by `new`, the parameterized `add_*` builders,
    /// and `register_cnn_params`).
    pub fn persistent_node_count(&self) -> usize {
        self.persistent_node_count
    }
1360
    /// Serializes the model into a [`SequentialCheckpoint`], reading current
    /// parameter values out of `graph` where the layer stores them there.
    ///
    /// LoRA layers are folded into plain `Linear` checkpoints via their
    /// merged weight, so the checkpoint format carries no LoRA state.
    ///
    /// # Errors
    /// Returns [`ModelError::InferenceOnlyLayer`] for layer kinds that have
    /// no checkpoint representation, and propagates graph-lookup failures.
    ///
    /// # Panics
    /// Panics if a `Linear` layer's weight/bias nodes were never registered.
    pub fn checkpoint(&self, graph: &Graph) -> Result<SequentialCheckpoint, ModelError> {
        let mut layers = Vec::with_capacity(self.layers.len());
        for layer in &self.layers {
            match layer {
                ModelLayer::Linear(linear) => {
                    let weight = graph
                        .value(linear.weight_node().expect("linear layer has weight node"))?
                        .clone();
                    let bias = graph
                        .value(linear.bias_node().expect("linear layer has bias node"))?
                        .clone();
                    layers.push(LayerCheckpoint::Linear {
                        in_features: linear.in_features(),
                        out_features: linear.out_features(),
                        weight: TensorSnapshot::from_tensor(&weight),
                        bias: TensorSnapshot::from_tensor(&bias),
                    });
                }
                ModelLayer::ReLU(_) => layers.push(LayerCheckpoint::ReLU),
                ModelLayer::LeakyReLU(layer) => layers.push(LayerCheckpoint::LeakyReLU {
                    negative_slope: layer.negative_slope(),
                }),
                ModelLayer::Sigmoid(_) => layers.push(LayerCheckpoint::Sigmoid),
                ModelLayer::Tanh(_) => layers.push(LayerCheckpoint::Tanh),
                ModelLayer::Dropout(layer) => {
                    layers.push(LayerCheckpoint::Dropout { rate: layer.rate() })
                }
                // Conv/BatchNorm state is read from the layer's own tensors,
                // not the graph; call sync_cnn_from_graph first if the graph
                // holds newer values.
                ModelLayer::Conv2d(layer) => layers.push(LayerCheckpoint::Conv2d {
                    in_channels: layer.in_channels(),
                    out_channels: layer.out_channels(),
                    kernel_h: layer.kernel_h(),
                    kernel_w: layer.kernel_w(),
                    stride_h: layer.stride_h(),
                    stride_w: layer.stride_w(),
                    weight: TensorSnapshot::from_tensor(layer.weight()),
                    bias: layer.bias().map(TensorSnapshot::from_tensor),
                }),
                ModelLayer::BatchNorm2d(layer) => layers.push(LayerCheckpoint::BatchNorm2d {
                    num_features: layer.num_features(),
                    epsilon: layer.epsilon(),
                    gamma: TensorSnapshot::from_tensor(layer.gamma()),
                    beta: TensorSnapshot::from_tensor(layer.beta()),
                    running_mean: TensorSnapshot::from_tensor(layer.running_mean()),
                    running_var: TensorSnapshot::from_tensor(layer.running_var()),
                }),
                ModelLayer::MaxPool2d(layer) => layers.push(LayerCheckpoint::MaxPool2d {
                    kernel_h: layer.kernel_h(),
                    kernel_w: layer.kernel_w(),
                    stride_h: layer.stride_h(),
                    stride_w: layer.stride_w(),
                }),
                ModelLayer::AvgPool2d(layer) => layers.push(LayerCheckpoint::AvgPool2d {
                    kernel_h: layer.kernel_h(),
                    kernel_w: layer.kernel_w(),
                    stride_h: layer.stride_h(),
                    stride_w: layer.stride_w(),
                }),
                ModelLayer::Flatten(_) => layers.push(LayerCheckpoint::Flatten),
                ModelLayer::GlobalAvgPool2d(_) => layers.push(LayerCheckpoint::GlobalAvgPool2d),
                ModelLayer::Softmax(_) => layers.push(LayerCheckpoint::Softmax),
                ModelLayer::Embedding(layer) => {
                    let w = graph.value(layer.weight_node())?;
                    layers.push(LayerCheckpoint::Embedding {
                        num_embeddings: layer.num_embeddings(),
                        embedding_dim: layer.embedding_dim(),
                        weight: TensorSnapshot::from_tensor(w),
                    });
                }
                ModelLayer::LayerNorm(layer) => {
                    let g = graph.value(layer.gamma_node())?;
                    let b = graph.value(layer.beta_node())?;
                    layers.push(LayerCheckpoint::LayerNorm {
                        normalized_shape: layer.normalized_shape(),
                        // NOTE(review): eps is hardcoded to 1e-5 rather than
                        // read from the layer — confirm it cannot differ.
                        eps: 1e-5,
                        gamma: TensorSnapshot::from_tensor(g),
                        beta: TensorSnapshot::from_tensor(b),
                    });
                }
                ModelLayer::GroupNorm(layer) => {
                    let g = graph.value(layer.gamma_node())?;
                    let b = graph.value(layer.beta_node())?;
                    layers.push(LayerCheckpoint::GroupNorm {
                        num_groups: layer.num_groups(),
                        num_channels: layer.num_channels(),
                        // NOTE(review): eps hardcoded, same caveat as above.
                        eps: 1e-5,
                        gamma: TensorSnapshot::from_tensor(g),
                        beta: TensorSnapshot::from_tensor(b),
                    });
                }
                ModelLayer::DepthwiseConv2d(layer) => {
                    layers.push(LayerCheckpoint::DepthwiseConv2d {
                        channels: layer.channels(),
                        kernel_h: layer.kernel_h(),
                        kernel_w: layer.kernel_w(),
                        stride_h: layer.stride_h(),
                        stride_w: layer.stride_w(),
                        weight: TensorSnapshot::from_tensor(layer.weight()),
                        bias: layer.bias().map(TensorSnapshot::from_tensor),
                    });
                }
                ModelLayer::SeparableConv2d(layer) => {
                    layers.push(LayerCheckpoint::SeparableConv2d {
                        in_channels: layer.in_channels(),
                        out_channels: layer.out_channels(),
                        kernel_h: layer.kernel_h(),
                        kernel_w: layer.kernel_w(),
                        stride_h: layer.stride_h(),
                        stride_w: layer.stride_w(),
                        depthwise_weight: TensorSnapshot::from_tensor(layer.depthwise().weight()),
                        pointwise_weight: TensorSnapshot::from_tensor(layer.pointwise().weight()),
                        bias: layer.pointwise().bias().map(TensorSnapshot::from_tensor),
                    });
                }
                ModelLayer::LoraLinear(lora) => {
                    // Fold the LoRA delta into the base weight and save as a
                    // plain Linear; a missing bias is serialized as zeros so
                    // the restored layer always carries one.
                    let merged_weight = lora.merge(graph)?;
                    let bias = if let Some(bias_node) = lora.bias {
                        graph.value(bias_node)?.clone()
                    } else {
                        Tensor::zeros(vec![lora.out_features])?
                    };
                    layers.push(LayerCheckpoint::Linear {
                        in_features: lora.in_features,
                        out_features: lora.out_features,
                        weight: TensorSnapshot::from_tensor(&merged_weight),
                        bias: TensorSnapshot::from_tensor(&bias),
                    });
                }
                // No checkpoint representation exists for these layers yet.
                ModelLayer::Conv1d(_)
                | ModelLayer::Conv3d(_)
                | ModelLayer::ConvTranspose2d(_)
                | ModelLayer::AdaptiveAvgPool2d(_)
                | ModelLayer::AdaptiveMaxPool2d(_)
                | ModelLayer::InstanceNorm(_)
                | ModelLayer::PixelShuffle(_)
                | ModelLayer::Upsample(_)
                | ModelLayer::GELU(_)
                | ModelLayer::SiLU(_)
                | ModelLayer::Mish(_)
                | ModelLayer::PReLU(_)
                | ModelLayer::ResidualBlock(_)
                | ModelLayer::Rnn(_)
                | ModelLayer::Lstm(_)
                | ModelLayer::Gru(_)
                | ModelLayer::MultiHeadAttention(_)
                | ModelLayer::TransformerEncoder(_)
                | ModelLayer::FeedForward(_)
                | ModelLayer::DeformableConv2d(_) => {
                    return Err(ModelError::InferenceOnlyLayer);
                }
            }
        }
        Ok(SequentialCheckpoint { layers })
    }
1516
    /// Rebuilds a model from `checkpoint`, registering restored parameters
    /// as new nodes in `graph` via the regular `add_*` builders.
    ///
    /// # Errors
    /// Propagates snapshot-to-tensor conversion failures and any error from
    /// the individual layer builders.
    pub fn from_checkpoint(
        graph: &mut Graph,
        checkpoint: &SequentialCheckpoint,
    ) -> Result<Self, ModelError> {
        let mut model = Self::new(graph);
        for layer in &checkpoint.layers {
            match layer {
                LayerCheckpoint::Linear {
                    in_features,
                    out_features,
                    weight,
                    bias,
                } => {
                    model.add_linear(
                        graph,
                        *in_features,
                        *out_features,
                        weight.clone().into_tensor()?,
                        bias.clone().into_tensor()?,
                    )?;
                }
                LayerCheckpoint::ReLU => model.add_relu(),
                LayerCheckpoint::LeakyReLU { negative_slope } => {
                    model.add_leaky_relu(*negative_slope)?
                }
                LayerCheckpoint::Sigmoid => model.add_sigmoid(),
                LayerCheckpoint::Tanh => model.add_tanh(),
                LayerCheckpoint::Dropout { rate } => model.add_dropout(*rate)?,
                LayerCheckpoint::Conv2d {
                    in_channels,
                    out_channels,
                    kernel_h,
                    kernel_w,
                    stride_h,
                    stride_w,
                    weight,
                    bias,
                } => {
                    // Bias is optional in the checkpoint; convert only when present.
                    let bias_tensor = match bias {
                        Some(b) => Some(b.clone().into_tensor()?),
                        None => None,
                    };
                    model.add_conv2d(
                        *in_channels,
                        *out_channels,
                        *kernel_h,
                        *kernel_w,
                        *stride_h,
                        *stride_w,
                        weight.clone().into_tensor()?,
                        bias_tensor,
                    )?;
                }
                LayerCheckpoint::BatchNorm2d {
                    num_features,
                    epsilon,
                    gamma,
                    beta,
                    running_mean,
                    running_var,
                } => {
                    model.add_batch_norm2d(
                        *num_features,
                        *epsilon,
                        gamma.clone().into_tensor()?,
                        beta.clone().into_tensor()?,
                        running_mean.clone().into_tensor()?,
                        running_var.clone().into_tensor()?,
                    )?;
                }
                LayerCheckpoint::MaxPool2d {
                    kernel_h,
                    kernel_w,
                    stride_h,
                    stride_w,
                } => model.add_max_pool2d(*kernel_h, *kernel_w, *stride_h, *stride_w)?,
                LayerCheckpoint::AvgPool2d {
                    kernel_h,
                    kernel_w,
                    stride_h,
                    stride_w,
                } => model.add_avg_pool2d(*kernel_h, *kernel_w, *stride_h, *stride_w)?,
                LayerCheckpoint::Flatten => model.add_flatten(),
                LayerCheckpoint::GlobalAvgPool2d => model.add_global_avg_pool2d(),
                LayerCheckpoint::Softmax => model.add_softmax(),
                LayerCheckpoint::Embedding {
                    num_embeddings,
                    embedding_dim,
                    weight,
                } => {
                    model.add_embedding(
                        graph,
                        *num_embeddings,
                        *embedding_dim,
                        weight.clone().into_tensor()?,
                    )?;
                }
                // NOTE(review): the saved gamma/beta snapshots are discarded
                // here — the restored LayerNorm re-initializes its
                // parameters, so learned norm weights do not survive a
                // save/load round trip. Confirm whether this is intentional.
                LayerCheckpoint::LayerNorm {
                    normalized_shape,
                    eps,
                    gamma: _,
                    beta: _,
                } => {
                    model.add_layer_norm(graph, *normalized_shape, *eps)?;
                }
                // NOTE(review): same caveat as LayerNorm — gamma/beta from
                // the checkpoint are ignored.
                LayerCheckpoint::GroupNorm {
                    num_groups,
                    num_channels,
                    eps,
                    gamma: _,
                    beta: _,
                } => {
                    model.add_group_norm(graph, *num_groups, *num_channels, *eps)?;
                }
                LayerCheckpoint::DepthwiseConv2d {
                    channels,
                    kernel_h,
                    kernel_w,
                    stride_h,
                    stride_w,
                    weight,
                    bias,
                } => {
                    let bias_tensor = match bias {
                        Some(b) => Some(b.clone().into_tensor()?),
                        None => None,
                    };
                    model.add_depthwise_conv2d(
                        *channels,
                        *kernel_h,
                        *kernel_w,
                        *stride_h,
                        *stride_w,
                        weight.clone().into_tensor()?,
                        bias_tensor,
                    )?;
                }
                LayerCheckpoint::SeparableConv2d {
                    in_channels,
                    out_channels,
                    kernel_h,
                    kernel_w,
                    stride_h,
                    stride_w,
                    depthwise_weight,
                    pointwise_weight,
                    bias,
                } => {
                    let bias_tensor = match bias {
                        Some(b) => Some(b.clone().into_tensor()?),
                        None => None,
                    };
                    model.add_separable_conv2d(
                        *in_channels,
                        *out_channels,
                        *kernel_h,
                        *kernel_w,
                        *stride_h,
                        *stride_w,
                        depthwise_weight.clone().into_tensor()?,
                        pointwise_weight.clone().into_tensor()?,
                        bias_tensor,
                    )?;
                }
            }
        }
        Ok(model)
    }
1685}
1686
/// Formats a parameter count for human display: millions as "x.xM",
/// thousands as "x.xK", everything below 1000 verbatim.
fn format_param_count(n: usize) -> String {
    const MILLION: f64 = 1_000_000.0;
    const THOUSAND: f64 = 1_000.0;
    match n {
        1_000_000.. => format!("{:.1}M", n as f64 / MILLION),
        1_000.. => format!("{:.1}K", n as f64 / THOUSAND),
        _ => n.to_string(),
    }
}