// yscv_model/sequential.rs
1use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use crate::{
5    AvgPool2dLayer, BatchNorm2dLayer, Conv2dLayer, DeformableConv2dLayer, DepthwiseConv2dLayer,
6    DropoutLayer, EmbeddingLayer, FeedForwardLayer, FlattenLayer, GlobalAvgPool2dLayer,
7    GroupNormLayer, GruLayer, LayerCheckpoint, LayerNormLayer, LeakyReLULayer, LinearLayer,
8    LoraConfig, LoraLinear, LstmLayer, MaxPool2dLayer, ModelError, ModelLayer,
9    MultiHeadAttentionLayer, ReLULayer, ResidualBlock, RnnLayer, SeparableConv2dLayer,
10    SequentialCheckpoint, SigmoidLayer, SoftmaxLayer, TanhLayer, TensorSnapshot,
11    TransformerEncoderLayer, optimize_sequential,
12};
13
/// Ordered stack of layers executed one-by-one.
#[derive(Debug, Clone)]
pub struct SequentialModel {
    // Layers in execution order. `frozen` is kept index-aligned with this vec:
    // every add_* method pushes to both.
    layers: Vec<ModelLayer>,
    // Per-layer freeze flags; frozen layers are skipped by `trainable_parameters`.
    frozen: Vec<bool>,
    // `graph.node_count()` recorded at construction and refreshed after every
    // parameter-registering add (linear, embedding, norm, LoRA). Presumably the
    // prefix of graph nodes below this index holds persistent model parameters.
    persistent_node_count: usize,
    // Training vs. evaluation mode; propagated to dropout layers (see set_training).
    training: bool,
}
22
23impl SequentialModel {
24    /// Creates an empty model and records current graph prefix as persistent base.
25    pub fn new(graph: &Graph) -> Self {
26        Self {
27            layers: Vec::new(),
28            frozen: Vec::new(),
29            persistent_node_count: graph.node_count(),
30            training: true,
31        }
32    }
33
34    pub fn add_linear(
35        &mut self,
36        graph: &mut Graph,
37        in_features: usize,
38        out_features: usize,
39        weight_init: Tensor,
40        bias_init: Tensor,
41    ) -> Result<(), ModelError> {
42        let layer = LinearLayer::new(graph, in_features, out_features, weight_init, bias_init)?;
43        self.layers.push(ModelLayer::Linear(layer));
44        self.frozen.push(false);
45        self.persistent_node_count = graph.node_count();
46        Ok(())
47    }
48
49    pub fn add_linear_zero(
50        &mut self,
51        graph: &mut Graph,
52        in_features: usize,
53        out_features: usize,
54    ) -> Result<(), ModelError> {
55        let layer = LinearLayer::zero_init(graph, in_features, out_features)?;
56        self.layers.push(ModelLayer::Linear(layer));
57        self.frozen.push(false);
58        self.persistent_node_count = graph.node_count();
59        Ok(())
60    }
61
62    pub fn add_relu(&mut self) {
63        self.layers.push(ModelLayer::ReLU(ReLULayer::new()));
64        self.frozen.push(false);
65    }
66
67    pub fn add_leaky_relu(&mut self, negative_slope: f32) -> Result<(), ModelError> {
68        let layer = LeakyReLULayer::new(negative_slope)?;
69        self.layers.push(ModelLayer::LeakyReLU(layer));
70        self.frozen.push(false);
71        Ok(())
72    }
73
74    pub fn add_sigmoid(&mut self) {
75        self.layers.push(ModelLayer::Sigmoid(SigmoidLayer::new()));
76        self.frozen.push(false);
77    }
78
79    pub fn add_tanh(&mut self) {
80        self.layers.push(ModelLayer::Tanh(TanhLayer::new()));
81        self.frozen.push(false);
82    }
83
84    pub fn add_gelu(&mut self) {
85        self.layers.push(ModelLayer::GELU(crate::GELULayer::new()));
86        self.frozen.push(false);
87    }
88
89    pub fn add_silu(&mut self) {
90        self.layers.push(ModelLayer::SiLU(crate::SiLULayer::new()));
91        self.frozen.push(false);
92    }
93
94    pub fn add_mish(&mut self) {
95        self.layers.push(ModelLayer::Mish(crate::MishLayer::new()));
96        self.frozen.push(false);
97    }
98
99    pub fn add_prelu(&mut self, alpha: Vec<f32>) {
100        self.layers
101            .push(ModelLayer::PReLU(crate::PReLULayer::new(alpha)));
102        self.frozen.push(false);
103    }
104
105    pub fn add_dropout(&mut self, rate: f32) -> Result<(), ModelError> {
106        let layer = DropoutLayer::new(rate)?;
107        self.layers.push(ModelLayer::Dropout(layer));
108        self.frozen.push(false);
109        Ok(())
110    }
111
112    #[allow(clippy::too_many_arguments)]
113    pub fn add_conv2d(
114        &mut self,
115        in_channels: usize,
116        out_channels: usize,
117        kernel_h: usize,
118        kernel_w: usize,
119        stride_h: usize,
120        stride_w: usize,
121        weight: Tensor,
122        bias: Option<Tensor>,
123    ) -> Result<(), ModelError> {
124        let layer = Conv2dLayer::new(
125            in_channels,
126            out_channels,
127            kernel_h,
128            kernel_w,
129            stride_h,
130            stride_w,
131            weight,
132            bias,
133        )?;
134        self.layers.push(ModelLayer::Conv2d(layer));
135        self.frozen.push(false);
136        Ok(())
137    }
138
139    #[allow(clippy::too_many_arguments)]
140    pub fn add_conv2d_zero(
141        &mut self,
142        in_channels: usize,
143        out_channels: usize,
144        kernel_h: usize,
145        kernel_w: usize,
146        stride_h: usize,
147        stride_w: usize,
148        use_bias: bool,
149    ) -> Result<(), ModelError> {
150        let layer = Conv2dLayer::zero_init(
151            in_channels,
152            out_channels,
153            kernel_h,
154            kernel_w,
155            stride_h,
156            stride_w,
157            use_bias,
158        )?;
159        self.layers.push(ModelLayer::Conv2d(layer));
160        self.frozen.push(false);
161        Ok(())
162    }
163
164    #[allow(clippy::too_many_arguments)]
165    pub fn add_deformable_conv2d(
166        &mut self,
167        in_channels: usize,
168        out_channels: usize,
169        kernel_h: usize,
170        kernel_w: usize,
171        stride: usize,
172        padding: usize,
173        weight: Tensor,
174        offset_weight: Tensor,
175        bias: Option<Tensor>,
176    ) -> Result<(), ModelError> {
177        let layer = DeformableConv2dLayer::new(
178            in_channels,
179            out_channels,
180            kernel_h,
181            kernel_w,
182            stride,
183            padding,
184            weight,
185            offset_weight,
186            bias,
187        )?;
188        self.layers.push(ModelLayer::DeformableConv2d(layer));
189        self.frozen.push(false);
190        Ok(())
191    }
192
193    pub fn add_deformable_conv2d_zero(
194        &mut self,
195        in_channels: usize,
196        out_channels: usize,
197        kernel_h: usize,
198        kernel_w: usize,
199        stride: usize,
200        padding: usize,
201        use_bias: bool,
202    ) -> Result<(), ModelError> {
203        let layer = DeformableConv2dLayer::zero_init(
204            in_channels,
205            out_channels,
206            kernel_h,
207            kernel_w,
208            stride,
209            padding,
210            use_bias,
211        )?;
212        self.layers.push(ModelLayer::DeformableConv2d(layer));
213        self.frozen.push(false);
214        Ok(())
215    }
216
217    #[allow(clippy::too_many_arguments)]
218    pub fn add_depthwise_conv2d(
219        &mut self,
220        channels: usize,
221        kernel_h: usize,
222        kernel_w: usize,
223        stride_h: usize,
224        stride_w: usize,
225        weight: Tensor,
226        bias: Option<Tensor>,
227    ) -> Result<(), ModelError> {
228        let layer = DepthwiseConv2dLayer::new(
229            channels, kernel_h, kernel_w, stride_h, stride_w, weight, bias,
230        )?;
231        self.layers.push(ModelLayer::DepthwiseConv2d(layer));
232        self.frozen.push(false);
233        Ok(())
234    }
235
236    pub fn add_depthwise_conv2d_zero(
237        &mut self,
238        channels: usize,
239        kernel_h: usize,
240        kernel_w: usize,
241        stride_h: usize,
242        stride_w: usize,
243        use_bias: bool,
244    ) -> Result<(), ModelError> {
245        let layer = DepthwiseConv2dLayer::zero_init(
246            channels, kernel_h, kernel_w, stride_h, stride_w, use_bias,
247        )?;
248        self.layers.push(ModelLayer::DepthwiseConv2d(layer));
249        self.frozen.push(false);
250        Ok(())
251    }
252
253    #[allow(clippy::too_many_arguments)]
254    pub fn add_separable_conv2d(
255        &mut self,
256        in_channels: usize,
257        out_channels: usize,
258        kernel_h: usize,
259        kernel_w: usize,
260        stride_h: usize,
261        stride_w: usize,
262        depthwise_weight: Tensor,
263        pointwise_weight: Tensor,
264        bias: Option<Tensor>,
265    ) -> Result<(), ModelError> {
266        let layer = SeparableConv2dLayer::new(
267            in_channels,
268            out_channels,
269            kernel_h,
270            kernel_w,
271            stride_h,
272            stride_w,
273            depthwise_weight,
274            pointwise_weight,
275            bias,
276        )?;
277        self.layers.push(ModelLayer::SeparableConv2d(layer));
278        self.frozen.push(false);
279        Ok(())
280    }
281
282    #[allow(clippy::too_many_arguments)]
283    pub fn add_separable_conv2d_zero(
284        &mut self,
285        in_channels: usize,
286        out_channels: usize,
287        kernel_h: usize,
288        kernel_w: usize,
289        stride_h: usize,
290        stride_w: usize,
291        use_bias: bool,
292    ) -> Result<(), ModelError> {
293        let layer = SeparableConv2dLayer::zero_init(
294            in_channels,
295            out_channels,
296            kernel_h,
297            kernel_w,
298            stride_h,
299            stride_w,
300            use_bias,
301        )?;
302        self.layers.push(ModelLayer::SeparableConv2d(layer));
303        self.frozen.push(false);
304        Ok(())
305    }
306
307    pub fn add_batch_norm2d(
308        &mut self,
309        num_features: usize,
310        epsilon: f32,
311        gamma: Tensor,
312        beta: Tensor,
313        running_mean: Tensor,
314        running_var: Tensor,
315    ) -> Result<(), ModelError> {
316        let layer = BatchNorm2dLayer::new(
317            num_features,
318            epsilon,
319            gamma,
320            beta,
321            running_mean,
322            running_var,
323        )?;
324        self.layers.push(ModelLayer::BatchNorm2d(layer));
325        self.frozen.push(false);
326        Ok(())
327    }
328
329    pub fn add_batch_norm2d_identity(
330        &mut self,
331        num_features: usize,
332        epsilon: f32,
333    ) -> Result<(), ModelError> {
334        let layer = BatchNorm2dLayer::identity_init(num_features, epsilon)?;
335        self.layers.push(ModelLayer::BatchNorm2d(layer));
336        self.frozen.push(false);
337        Ok(())
338    }
339
340    pub fn add_max_pool2d(
341        &mut self,
342        kernel_h: usize,
343        kernel_w: usize,
344        stride_h: usize,
345        stride_w: usize,
346    ) -> Result<(), ModelError> {
347        let layer = MaxPool2dLayer::new(kernel_h, kernel_w, stride_h, stride_w)?;
348        self.layers.push(ModelLayer::MaxPool2d(layer));
349        self.frozen.push(false);
350        Ok(())
351    }
352
353    pub fn add_avg_pool2d(
354        &mut self,
355        kernel_h: usize,
356        kernel_w: usize,
357        stride_h: usize,
358        stride_w: usize,
359    ) -> Result<(), ModelError> {
360        let layer = AvgPool2dLayer::new(kernel_h, kernel_w, stride_h, stride_w)?;
361        self.layers.push(ModelLayer::AvgPool2d(layer));
362        self.frozen.push(false);
363        Ok(())
364    }
365
366    pub fn add_flatten(&mut self) {
367        self.layers.push(ModelLayer::Flatten(FlattenLayer::new()));
368        self.frozen.push(false);
369    }
370
371    pub fn add_global_avg_pool2d(&mut self) {
372        self.layers
373            .push(ModelLayer::GlobalAvgPool2d(GlobalAvgPool2dLayer::new()));
374        self.frozen.push(false);
375    }
376
377    pub fn add_softmax(&mut self) {
378        self.layers.push(ModelLayer::Softmax(SoftmaxLayer::new()));
379        self.frozen.push(false);
380    }
381
382    pub fn add_embedding(
383        &mut self,
384        graph: &mut Graph,
385        num_embeddings: usize,
386        embedding_dim: usize,
387        weight_init: Tensor,
388    ) -> Result<(), ModelError> {
389        let layer = EmbeddingLayer::new(graph, num_embeddings, embedding_dim, weight_init)?;
390        self.layers.push(ModelLayer::Embedding(layer));
391        self.frozen.push(false);
392        Ok(())
393    }
394
395    pub fn add_layer_norm(
396        &mut self,
397        graph: &mut Graph,
398        normalized_shape: usize,
399        eps: f32,
400    ) -> Result<(), ModelError> {
401        let layer = LayerNormLayer::new(graph, normalized_shape, eps)?;
402        self.layers.push(ModelLayer::LayerNorm(layer));
403        self.frozen.push(false);
404        Ok(())
405    }
406
407    pub fn add_group_norm(
408        &mut self,
409        graph: &mut Graph,
410        num_groups: usize,
411        num_channels: usize,
412        eps: f32,
413    ) -> Result<(), ModelError> {
414        let layer = GroupNormLayer::new(graph, num_groups, num_channels, eps)?;
415        self.layers.push(ModelLayer::GroupNorm(layer));
416        self.frozen.push(false);
417        Ok(())
418    }
419
    /// Replace all Linear layers with LoRA-adapted versions.
    /// The original weights are frozen; only the low-rank A and B matrices are trainable.
    /// Returns the number of layers converted.
    ///
    /// # Panics
    /// Panics if a `Linear` layer has no weight or bias node registered on the
    /// graph (both are expected to exist for every constructed `LinearLayer`).
    pub fn apply_lora(
        &mut self,
        graph: &mut Graph,
        config: &LoraConfig,
    ) -> Result<usize, ModelError> {
        let mut count = 0;
        for layer in self.layers.iter_mut() {
            if let ModelLayer::Linear(linear) = layer {
                let in_features = linear.in_features();
                let out_features = linear.out_features();
                // NOTE(review): these `expect`s assume every LinearLayer registers
                // both nodes — confirm against LinearLayer::new / zero_init.
                let weight_node = linear.weight_node().expect("linear layer has weight node");
                let bias_node = linear.bias_node().expect("linear layer has bias node");
                let lora = LoraLinear::from_linear(
                    graph,
                    weight_node,
                    bias_node,
                    in_features,
                    out_features,
                    config,
                )?;
                // Swap the Linear variant for its LoRA-wrapped replacement in place.
                *layer = ModelLayer::LoraLinear(lora);
                count += 1;
            }
        }
        // LoRA adapter matrices were added to the graph; refresh the persistent prefix.
        self.persistent_node_count = graph.node_count();
        Ok(count)
    }
450
    /// Merge all LoRA layers back into regular Linear layers.
    /// Call this after fine-tuning for inference without overhead.
    /// Returns the number of layers merged.
    pub fn merge_lora(&mut self, graph: &mut Graph) -> Result<usize, ModelError> {
        let mut count = 0;
        for layer in self.layers.iter_mut() {
            if let ModelLayer::LoraLinear(lora) = layer {
                // Materialize the merged weight tensor (presumably base weight plus
                // the low-rank update — see LoraLinear::merge).
                let merged_weight = lora.merge(graph)?;
                let in_features = lora.in_features;
                let out_features = lora.out_features;
                // Reuse the adapter's bias if it has one; otherwise substitute zeros,
                // since LinearLayer::new requires a bias tensor.
                let bias_tensor = if let Some(bias_node) = lora.bias {
                    graph.value(bias_node)?.clone()
                } else {
                    Tensor::zeros(vec![out_features])?
                };
                let linear =
                    LinearLayer::new(graph, in_features, out_features, merged_weight, bias_tensor)?;
                // Swap the LoRA variant back to a plain Linear in place.
                *layer = ModelLayer::Linear(linear);
                count += 1;
            }
        }
        // Merged parameters now live on the graph; refresh the persistent prefix.
        self.persistent_node_count = graph.node_count();
        Ok(count)
    }
475
476    /// Optimize the model by fusing Conv+BN layers.
477    /// This reduces the number of layers and operations for faster inference.
478    /// Should be called before inference, after training is complete.
479    /// Returns the number of fusions performed.
480    pub fn optimize(&mut self, graph: &mut Graph) -> usize {
481        let before = self.layers.len();
482        let optimized = optimize_sequential(self, graph);
483        let after = optimized.layers.len();
484        *self = optimized;
485        before - after
486    }
487
488    /// Set training/eval mode for all dropout layers.
489    pub fn set_training(&mut self, training: bool) {
490        self.training = training;
491        for layer in &mut self.layers {
492            if let ModelLayer::Dropout(d) = layer {
493                d.set_training(training);
494            }
495        }
496    }
497
    /// Switch to evaluation mode (disables dropout).
    ///
    /// Convenience wrapper around [`SequentialModel::set_training`] with `false`.
    pub fn eval(&mut self) {
        self.set_training(false);
    }
502
    /// Switch to training mode (enables dropout).
    ///
    /// Convenience wrapper around [`SequentialModel::set_training`] with `true`.
    pub fn train_mode(&mut self) {
        self.set_training(true);
    }
507
    /// Returns whether the model is in training mode.
    pub fn is_training(&self) -> bool {
        self.training
    }
512
513    /// Print a human-readable summary of the model architecture.
514    ///
515    /// Shows each layer's type, output shape info, and parameter count.
516    /// Returns a `String` so callers can print or log it.
517    pub fn summary(&self) -> String {
518        use std::fmt::Write;
519
520        let mut out = String::new();
521        let sep = "-".repeat(65);
522        writeln!(out, "{sep}").expect("write to String");
523        writeln!(
524            out,
525            " {:>3}  {:<25} {:>20}  {:>8}",
526            "#", "Layer", "Details", "Params"
527        )
528        .expect("write to String");
529        writeln!(out, "{sep}").expect("write to String");
530
531        let mut total_params = 0usize;
532
533        for (i, layer) in self.layers.iter().enumerate() {
534            let (name, details, params) = match layer {
535                ModelLayer::Linear(l) => {
536                    let p = l.in_features() * l.out_features() + l.out_features();
537                    (
538                        "Linear".to_string(),
539                        format!("{}→{}", l.in_features(), l.out_features()),
540                        p,
541                    )
542                }
543                ModelLayer::Conv2d(l) => {
544                    let p = l.out_channels() * l.in_channels() * l.kernel_h() * l.kernel_w()
545                        + if l.bias().is_some() {
546                            l.out_channels()
547                        } else {
548                            0
549                        };
550                    (
551                        "Conv2d".to_string(),
552                        format!(
553                            "{}→{} {}x{} s{}",
554                            l.in_channels(),
555                            l.out_channels(),
556                            l.kernel_h(),
557                            l.kernel_w(),
558                            l.stride_h(),
559                        ),
560                        p,
561                    )
562                }
563                ModelLayer::BatchNorm2d(l) => {
564                    let p = l.num_features() * 2; // gamma + beta
565                    (
566                        "BatchNorm2d".to_string(),
567                        format!("{}", l.num_features()),
568                        p,
569                    )
570                }
571                ModelLayer::MaxPool2d(l) => (
572                    "MaxPool2d".to_string(),
573                    format!("{}x{} s{}", l.kernel_h(), l.kernel_w(), l.stride_h()),
574                    0,
575                ),
576                ModelLayer::AvgPool2d(l) => (
577                    "AvgPool2d".to_string(),
578                    format!("{}x{} s{}", l.kernel_h(), l.kernel_w(), l.stride_h()),
579                    0,
580                ),
581                ModelLayer::GlobalAvgPool2d(_) => ("GlobalAvgPool2d".to_string(), String::new(), 0),
582                ModelLayer::Flatten(_) => ("Flatten".to_string(), String::new(), 0),
583                ModelLayer::ReLU(_) => ("ReLU".to_string(), String::new(), 0),
584                ModelLayer::LeakyReLU(_) => ("LeakyReLU".to_string(), String::new(), 0),
585                ModelLayer::Sigmoid(_) => ("Sigmoid".to_string(), String::new(), 0),
586                ModelLayer::Tanh(_) => ("Tanh".to_string(), String::new(), 0),
587                ModelLayer::Softmax(_) => ("Softmax".to_string(), String::new(), 0),
588                ModelLayer::Dropout(d) => ("Dropout".to_string(), format!("p={:.2}", d.rate()), 0),
589                ModelLayer::Embedding(e) => {
590                    let p = e.num_embeddings() * e.embedding_dim();
591                    (
592                        "Embedding".to_string(),
593                        format!("{}x{}", e.num_embeddings(), e.embedding_dim()),
594                        p,
595                    )
596                }
597                ModelLayer::LayerNorm(_) => ("LayerNorm".to_string(), String::new(), 0),
598                ModelLayer::GroupNorm(_) => ("GroupNorm".to_string(), String::new(), 0),
599                ModelLayer::DepthwiseConv2d(_) => ("DepthwiseConv2d".to_string(), String::new(), 0),
600                ModelLayer::SeparableConv2d(_) => ("SeparableConv2d".to_string(), String::new(), 0),
601                ModelLayer::DeformableConv2d(l) => {
602                    let p = l.weight.data().len() + l.offset_weight.data().len();
603                    (
604                        "DeformableConv2d".to_string(),
605                        format!(
606                            "{}→{} {}x{}",
607                            l.in_channels(),
608                            l.out_channels(),
609                            l.kernel_h,
610                            l.kernel_w
611                        ),
612                        p,
613                    )
614                }
615                ModelLayer::LoraLinear(l) => (
616                    "LoraLinear".to_string(),
617                    format!("{}→{} r={}", l.in_features, l.out_features, l.rank),
618                    l.in_features * l.rank + l.rank * l.out_features,
619                ),
620                ModelLayer::Conv1d(l) => {
621                    let p = l.kernel().data().len();
622                    ("Conv1d".to_string(), format!("k={}", l.kernel_size()), p)
623                }
624                ModelLayer::Conv3d(l) => {
625                    let p = l.weight().data().len();
626                    (
627                        "Conv3d".to_string(),
628                        format!("{}→{}", l.in_channels(), l.out_channels()),
629                        p,
630                    )
631                }
632                ModelLayer::ConvTranspose2d(l) => {
633                    let p = l.kernel().data().len();
634                    (
635                        "ConvTranspose2d".to_string(),
636                        format!("s={}", l.stride()),
637                        p,
638                    )
639                }
640                ModelLayer::AdaptiveAvgPool2d(l) => (
641                    "AdaptiveAvgPool2d".to_string(),
642                    format!("{}x{}", l.output_h(), l.output_w()),
643                    0,
644                ),
645                ModelLayer::AdaptiveMaxPool2d(l) => (
646                    "AdaptiveMaxPool2d".to_string(),
647                    format!("{}x{}", l.output_h(), l.output_w()),
648                    0,
649                ),
650                ModelLayer::InstanceNorm(_) => ("InstanceNorm".to_string(), String::new(), 0),
651                ModelLayer::PixelShuffle(l) => (
652                    "PixelShuffle".to_string(),
653                    format!("r={}", l.upscale_factor()),
654                    0,
655                ),
656                ModelLayer::Upsample(l) => {
657                    ("Upsample".to_string(), format!("{}x", l.scale_factor()), 0)
658                }
659                ModelLayer::GELU(_) => ("GELU".to_string(), String::new(), 0),
660                ModelLayer::SiLU(_) => ("SiLU".to_string(), String::new(), 0),
661                ModelLayer::Mish(_) => ("Mish".to_string(), String::new(), 0),
662                ModelLayer::PReLU(l) => (
663                    "PReLU".to_string(),
664                    format!("channels={}", l.alpha().len()),
665                    l.alpha().len(),
666                ),
667                ModelLayer::ResidualBlock(r) => (
668                    "ResidualBlock".to_string(),
669                    format!("{} layers", r.layers().len()),
670                    0,
671                ),
672                ModelLayer::Rnn(l) => {
673                    let p = l.input_size * l.hidden_size
674                        + l.hidden_size * l.hidden_size
675                        + l.hidden_size;
676                    (
677                        "Rnn".to_string(),
678                        format!("{}→{}", l.input_size, l.hidden_size),
679                        p,
680                    )
681                }
682                ModelLayer::Lstm(l) => {
683                    let h4 = 4 * l.hidden_size;
684                    let p = l.input_size * h4 + l.hidden_size * h4 + h4;
685                    (
686                        "Lstm".to_string(),
687                        format!("{}→{}", l.input_size, l.hidden_size),
688                        p,
689                    )
690                }
691                ModelLayer::Gru(l) => {
692                    let h3 = 3 * l.hidden_size;
693                    let p = l.input_size * h3 + l.hidden_size * h3 + 2 * h3;
694                    (
695                        "Gru".to_string(),
696                        format!("{}→{}", l.input_size, l.hidden_size),
697                        p,
698                    )
699                }
700                ModelLayer::MultiHeadAttention(_) => {
701                    ("MultiHeadAttention".to_string(), String::new(), 0)
702                }
703                ModelLayer::TransformerEncoder(_) => {
704                    ("TransformerEncoder".to_string(), String::new(), 0)
705                }
706                ModelLayer::FeedForward(_) => ("FeedForward".to_string(), String::new(), 0),
707            };
708
709            total_params += params;
710            let params_str = if params > 0 {
711                format_param_count(params)
712            } else {
713                "-".to_string()
714            };
715            writeln!(
716                out,
717                " {:>3}  {:<25} {:>20}  {:>8}",
718                i, name, details, params_str
719            )
720            .expect("write to String");
721        }
722
723        writeln!(out, "{sep}").expect("write to String");
724        writeln!(
725            out,
726            " Total: {} layers, {} parameters",
727            self.layers.len(),
728            format_param_count(total_params)
729        )
730        .expect("write to String");
731        writeln!(out, "{sep}").expect("write to String");
732        out
733    }
734
    /// Returns the total number of parameters (weights + biases) across all layers.
    ///
    /// Counts are derived from layer dimensions or stored tensor lengths and
    /// include frozen layers; see `trainable_parameters` for a frozen-aware count.
    pub fn num_parameters(&self) -> usize {
        let mut total = 0usize;
        for layer in &self.layers {
            let params = match layer {
                // weight (in×out) + bias (out)
                ModelLayer::Linear(l) => l.in_features() * l.out_features() + l.out_features(),
                ModelLayer::Conv2d(l) => {
                    // weight (out×in×kh×kw) + optional bias (out)
                    l.out_channels() * l.in_channels() * l.kernel_h() * l.kernel_w()
                        + if l.bias().is_some() {
                            l.out_channels()
                        } else {
                            0
                        }
                }
                // gamma + beta; running mean/var excluded from the count
                ModelLayer::BatchNorm2d(l) => l.num_features() * 2,
                ModelLayer::Embedding(e) => e.num_embeddings() * e.embedding_dim(),
                // only the low-rank A and B adapters count (base weight is frozen)
                ModelLayer::LoraLinear(l) => l.in_features * l.rank + l.rank * l.out_features,
                ModelLayer::Conv1d(l) => l.kernel().data().len(),
                ModelLayer::Conv3d(l) => l.weight().data().len(),
                ModelLayer::ConvTranspose2d(l) => l.kernel().data().len(),
                ModelLayer::DepthwiseConv2d(l) => {
                    l.weight().data().len() + l.bias().map_or(0, |b| b.data().len())
                }
                ModelLayer::SeparableConv2d(l) => {
                    // depthwise + pointwise weights and their optional biases
                    l.depthwise().weight().data().len()
                        + l.depthwise().bias().map_or(0, |b| b.data().len())
                        + l.pointwise().weight().data().len()
                        + l.pointwise().bias().map_or(0, |b| b.data().len())
                }
                ModelLayer::DeformableConv2d(l) => {
                    l.weight.data().len()
                        + l.offset_weight.data().len()
                        + l.bias.as_ref().map_or(0, |b| b.data().len())
                }
                ModelLayer::InstanceNorm(_) => 0,
                // parameter-free layers (pools, reshapes, fixed activations)
                ModelLayer::LayerNorm(_)
                | ModelLayer::GroupNorm(_)
                | ModelLayer::MaxPool2d(_)
                | ModelLayer::AvgPool2d(_)
                | ModelLayer::GlobalAvgPool2d(_)
                | ModelLayer::Flatten(_)
                | ModelLayer::ReLU(_)
                | ModelLayer::LeakyReLU(_)
                | ModelLayer::Sigmoid(_)
                | ModelLayer::Tanh(_)
                | ModelLayer::Softmax(_)
                | ModelLayer::Dropout(_)
                | ModelLayer::AdaptiveAvgPool2d(_)
                | ModelLayer::AdaptiveMaxPool2d(_)
                | ModelLayer::PixelShuffle(_)
                | ModelLayer::Upsample(_)
                | ModelLayer::GELU(_)
                | ModelLayer::SiLU(_)
                | ModelLayer::Mish(_) => 0,
                // one learnable slope per alpha entry
                ModelLayer::PReLU(l) => l.alpha().len(),
                // NOTE(review): composite layers report 0 here — their inner
                // parameters are not counted; confirm whether that is intended.
                ModelLayer::ResidualBlock(_) => 0,
                ModelLayer::Rnn(l) => {
                    // input-to-hidden + hidden-to-hidden + bias
                    l.input_size * l.hidden_size + l.hidden_size * l.hidden_size + l.hidden_size
                }
                ModelLayer::Lstm(l) => {
                    // four gates stacked as 4*hidden
                    let h4 = 4 * l.hidden_size;
                    l.input_size * h4 + l.hidden_size * h4 + h4
                }
                ModelLayer::Gru(l) => {
                    // three gates stacked as 3*hidden
                    let h3 = 3 * l.hidden_size;
                    l.input_size * h3 + l.hidden_size * h3 + 2 * h3
                }
                ModelLayer::MultiHeadAttention(_)
                | ModelLayer::TransformerEncoder(_)
                | ModelLayer::FeedForward(_) => 0,
            };
            total += params;
        }
        total
    }
810
811    /// Freeze the layer at `index` so it is excluded from `trainable_parameters`.
812    pub fn freeze_layer(&mut self, index: usize) -> Result<(), ModelError> {
813        if index >= self.layers.len() {
814            return Err(ModelError::InvalidLayerIndex {
815                index,
816                count: self.layers.len(),
817            });
818        }
819        self.frozen[index] = true;
820        Ok(())
821    }
822
823    /// Unfreeze the layer at `index` so it is included in `trainable_parameters` again.
824    pub fn unfreeze_layer(&mut self, index: usize) -> Result<(), ModelError> {
825        if index >= self.layers.len() {
826            return Err(ModelError::InvalidLayerIndex {
827                index,
828                count: self.layers.len(),
829            });
830        }
831        self.frozen[index] = false;
832        Ok(())
833    }
834
835    /// Returns a slice of booleans indicating which layers are frozen.
836    pub fn frozen_mask(&self) -> &[bool] {
837        &self.frozen
838    }
839
    /// Returns the total number of parameters in non-frozen layers.
    ///
    /// Counts are derived either from layer dimensions (Linear, Conv2d,
    /// Rnn/Lstm/Gru, ...) or from the actual lengths of layer-owned tensors
    /// (Conv1d/Conv3d/ConvTranspose2d/Depthwise/Separable/Deformable).
    /// Layers whose `frozen` flag is set are skipped entirely.
    pub fn trainable_parameters(&self) -> usize {
        let mut total = 0usize;
        for (i, layer) in self.layers.iter().enumerate() {
            // `frozen` is kept in lockstep with `layers` by every add_* method.
            if self.frozen[i] {
                continue;
            }
            let params = match layer {
                // weight (in × out) + bias (out).
                ModelLayer::Linear(l) => l.in_features() * l.out_features() + l.out_features(),
                ModelLayer::Conv2d(l) => {
                    l.out_channels() * l.in_channels() * l.kernel_h() * l.kernel_w()
                        + if l.bias().is_some() {
                            l.out_channels()
                        } else {
                            0
                        }
                }
                // gamma + beta; running stats are buffers, not parameters.
                ModelLayer::BatchNorm2d(l) => l.num_features() * 2,
                ModelLayer::Embedding(e) => e.num_embeddings() * e.embedding_dim(),
                // Only the low-rank factors A and B train; the merged base
                // weight is frozen by construction.
                // NOTE(review): a LoRA bias (`l.bias`) is not counted here
                // even though `named_parameters` exposes it — confirm whether
                // the bias is trainable.
                ModelLayer::LoraLinear(l) => l.in_features * l.rank + l.rank * l.out_features,
                ModelLayer::Conv1d(l) => l.kernel().data().len(),
                ModelLayer::Conv3d(l) => l.weight().data().len(),
                ModelLayer::ConvTranspose2d(l) => l.kernel().data().len(),
                ModelLayer::DepthwiseConv2d(l) => {
                    l.weight().data().len() + l.bias().map_or(0, |b| b.data().len())
                }
                ModelLayer::SeparableConv2d(l) => {
                    l.depthwise().weight().data().len()
                        + l.depthwise().bias().map_or(0, |b| b.data().len())
                        + l.pointwise().weight().data().len()
                        + l.pointwise().bias().map_or(0, |b| b.data().len())
                }
                ModelLayer::DeformableConv2d(l) => {
                    l.weight.data().len()
                        + l.offset_weight.data().len()
                        + l.bias.as_ref().map_or(0, |b| b.data().len())
                }
                ModelLayer::InstanceNorm(_) => 0,
                // Parameter-less layers: pure pooling / reshaping / activation.
                ModelLayer::LayerNorm(_)
                | ModelLayer::GroupNorm(_)
                | ModelLayer::MaxPool2d(_)
                | ModelLayer::AvgPool2d(_)
                | ModelLayer::GlobalAvgPool2d(_)
                | ModelLayer::Flatten(_)
                | ModelLayer::ReLU(_)
                | ModelLayer::LeakyReLU(_)
                | ModelLayer::Sigmoid(_)
                | ModelLayer::Tanh(_)
                | ModelLayer::Softmax(_)
                | ModelLayer::Dropout(_)
                | ModelLayer::AdaptiveAvgPool2d(_)
                | ModelLayer::AdaptiveMaxPool2d(_)
                | ModelLayer::PixelShuffle(_)
                | ModelLayer::Upsample(_)
                | ModelLayer::GELU(_)
                | ModelLayer::SiLU(_)
                | ModelLayer::Mish(_) => 0,
                // NOTE(review): LayerNorm/GroupNorm are counted as 0 here even
                // though they carry gamma/beta in the graph — confirm whether
                // graph-registered norm weights should count as trainable.
                ModelLayer::PReLU(l) => l.alpha().len(),
                ModelLayer::ResidualBlock(_) => 0,
                // W_ih (in × hidden) + W_hh (hidden × hidden) + bias (hidden).
                ModelLayer::Rnn(l) => {
                    l.input_size * l.hidden_size + l.hidden_size * l.hidden_size + l.hidden_size
                }
                // LSTM stacks 4 gates (i, f, g, o) per weight matrix.
                ModelLayer::Lstm(l) => {
                    let h4 = 4 * l.hidden_size;
                    l.input_size * h4 + l.hidden_size * h4 + h4
                }
                // GRU stacks 3 gates (r, z, n); two bias vectors per gate set.
                ModelLayer::Gru(l) => {
                    let h3 = 3 * l.hidden_size;
                    l.input_size * h3 + l.hidden_size * h3 + 2 * h3
                }
                // NOTE(review): attention/transformer/FFN layers report 0
                // parameters — presumably mirrors the total-count method, but
                // they do own weights; verify this is intentional.
                ModelLayer::MultiHeadAttention(_)
                | ModelLayer::TransformerEncoder(_)
                | ModelLayer::FeedForward(_) => 0,
            };
            total += params;
        }
        total
    }
918
    /// Returns named parameter tensors from all layers.
    ///
    /// For layers with graph-registered weights (Linear, Embedding, LayerNorm,
    /// GroupNorm, LoraLinear), the tensors are retrieved from the graph.
    /// For layers that own their tensors directly (Conv2d, BatchNorm2d,
    /// DepthwiseConv2d, SeparableConv2d, Conv1d, ConvTranspose2d), the tensors
    /// are accessed via the layer's accessor methods.
    ///
    /// Names follow the pattern `"{type}{index}_{param}"`, e.g. `"linear0_weight"`.
    /// The index is a per-type running counter, so the second Conv2d in the
    /// stack is `conv2d1` regardless of how many other layers sit between.
    ///
    /// # Errors
    ///
    /// Propagates `graph.value(..)` lookup failures for graph-registered
    /// weights. Panics (via `expect`) if a Linear layer has no registered
    /// weight/bias node — an invariant upheld by the `add_linear*` constructors.
    pub fn named_parameters<'a>(
        &'a self,
        graph: &'a Graph,
    ) -> Result<Vec<(String, &'a Tensor)>, ModelError> {
        let mut result = Vec::new();
        // Per-type running counters used to build unique parameter names.
        let mut type_counts = std::collections::HashMap::<&str, usize>::new();

        for layer in &self.layers {
            match layer {
                ModelLayer::Linear(l) => {
                    // Fetch-and-increment the per-type index.
                    let idx = type_counts.entry("linear").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((
                        format!("linear{i}_weight"),
                        graph.value(l.weight_node().expect("linear layer has weight node"))?,
                    ));
                    result.push((
                        format!("linear{i}_bias"),
                        graph.value(l.bias_node().expect("linear layer has bias node"))?,
                    ));
                }
                ModelLayer::Conv2d(l) => {
                    let idx = type_counts.entry("conv2d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("conv2d{i}_weight"), l.weight()));
                    // Bias is optional for conv layers.
                    if let Some(b) = l.bias() {
                        result.push((format!("conv2d{i}_bias"), b));
                    }
                }
                ModelLayer::BatchNorm2d(l) => {
                    let idx = type_counts.entry("batchnorm2d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("batchnorm2d{i}_gamma"), l.gamma()));
                    result.push((format!("batchnorm2d{i}_beta"), l.beta()));
                }
                ModelLayer::Embedding(e) => {
                    let idx = type_counts.entry("embedding").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((
                        format!("embedding{i}_weight"),
                        graph.value(e.weight_node())?,
                    ));
                }
                ModelLayer::LayerNorm(l) => {
                    let idx = type_counts.entry("layernorm").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("layernorm{i}_gamma"), graph.value(l.gamma_node())?));
                    result.push((format!("layernorm{i}_beta"), graph.value(l.beta_node())?));
                }
                ModelLayer::GroupNorm(l) => {
                    let idx = type_counts.entry("groupnorm").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("groupnorm{i}_gamma"), graph.value(l.gamma_node())?));
                    result.push((format!("groupnorm{i}_beta"), graph.value(l.beta_node())?));
                }
                ModelLayer::LoraLinear(l) => {
                    let idx = type_counts.entry("loralinear").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    // Frozen base weight is exposed too, so checkpointsaving
                    // callers can see the full effective parameter set.
                    result.push((
                        format!("loralinear{i}_frozen_weight"),
                        graph.value(l.frozen_weight)?,
                    ));
                    result.push((format!("loralinear{i}_lora_a"), graph.value(l.lora_a)?));
                    result.push((format!("loralinear{i}_lora_b"), graph.value(l.lora_b)?));
                    if let Some(bias_node) = l.bias {
                        result.push((format!("loralinear{i}_bias"), graph.value(bias_node)?));
                    }
                }
                ModelLayer::DepthwiseConv2d(l) => {
                    let idx = type_counts.entry("depthwiseconv2d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("depthwiseconv2d{i}_weight"), l.weight()));
                    if let Some(b) = l.bias() {
                        result.push((format!("depthwiseconv2d{i}_bias"), b));
                    }
                }
                ModelLayer::SeparableConv2d(l) => {
                    let idx = type_counts.entry("separableconv2d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((
                        format!("separableconv2d{i}_depthwise_weight"),
                        l.depthwise().weight(),
                    ));
                    if let Some(b) = l.depthwise().bias() {
                        result.push((format!("separableconv2d{i}_depthwise_bias"), b));
                    }
                    result.push((
                        format!("separableconv2d{i}_pointwise_weight"),
                        l.pointwise().weight(),
                    ));
                    if let Some(b) = l.pointwise().bias() {
                        result.push((format!("separableconv2d{i}_pointwise_bias"), b));
                    }
                }
                ModelLayer::Conv1d(l) => {
                    let idx = type_counts.entry("conv1d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("conv1d{i}_weight"), l.kernel()));
                }
                ModelLayer::Conv3d(l) => {
                    let idx = type_counts.entry("conv3d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("conv3d{i}_weight"), l.weight()));
                }
                ModelLayer::ConvTranspose2d(l) => {
                    let idx = type_counts.entry("convtranspose2d").or_insert(0);
                    let i = *idx;
                    *idx += 1;
                    result.push((format!("convtranspose2d{i}_weight"), l.kernel()));
                }
                // Layers without parameters
                // NOTE(review): DeformableConv2d is grouped here although
                // `trainable_parameters` counts its owned weight/offset/bias
                // tensors — confirm whether it should be exposed by name.
                ModelLayer::ReLU(_)
                | ModelLayer::LeakyReLU(_)
                | ModelLayer::Sigmoid(_)
                | ModelLayer::Tanh(_)
                | ModelLayer::Softmax(_)
                | ModelLayer::Dropout(_)
                | ModelLayer::MaxPool2d(_)
                | ModelLayer::AvgPool2d(_)
                | ModelLayer::GlobalAvgPool2d(_)
                | ModelLayer::Flatten(_)
                | ModelLayer::AdaptiveAvgPool2d(_)
                | ModelLayer::AdaptiveMaxPool2d(_)
                | ModelLayer::InstanceNorm(_)
                | ModelLayer::PixelShuffle(_)
                | ModelLayer::Upsample(_)
                | ModelLayer::GELU(_)
                | ModelLayer::SiLU(_)
                | ModelLayer::Mish(_)
                | ModelLayer::PReLU(_)
                | ModelLayer::ResidualBlock(_)
                | ModelLayer::Rnn(_)
                | ModelLayer::Lstm(_)
                | ModelLayer::Gru(_)
                | ModelLayer::MultiHeadAttention(_)
                | ModelLayer::TransformerEncoder(_)
                | ModelLayer::FeedForward(_)
                | ModelLayer::DeformableConv2d(_) => {}
            }
        }
        Ok(result)
    }
1081
1082    pub fn layers(&self) -> &[ModelLayer] {
1083        &self.layers
1084    }
1085
1086    pub fn layers_mut(&mut self) -> &mut [ModelLayer] {
1087        &mut self.layers
1088    }
1089
1090    /// Adds a residual block wrapping the given layers.
1091    pub fn add_residual_block(&mut self, layers: Vec<ModelLayer>) {
1092        self.layers
1093            .push(ModelLayer::ResidualBlock(ResidualBlock::new(layers)));
1094        self.frozen.push(false);
1095    }
1096
1097    pub fn add_rnn(&mut self, input_size: usize, hidden_size: usize, seed: u64) {
1098        self.layers.push(ModelLayer::Rnn(RnnLayer::new(
1099            input_size,
1100            hidden_size,
1101            seed,
1102        )));
1103        self.frozen.push(false);
1104    }
1105
1106    pub fn add_lstm(&mut self, input_size: usize, hidden_size: usize, seed: u64) {
1107        self.layers.push(ModelLayer::Lstm(LstmLayer::new(
1108            input_size,
1109            hidden_size,
1110            seed,
1111        )));
1112        self.frozen.push(false);
1113    }
1114
1115    pub fn add_gru(&mut self, input_size: usize, hidden_size: usize, seed: u64) {
1116        self.layers.push(ModelLayer::Gru(GruLayer::new(
1117            input_size,
1118            hidden_size,
1119            seed,
1120        )));
1121        self.frozen.push(false);
1122    }
1123
1124    pub fn add_multi_head_attention(&mut self, d_model: usize, num_heads: usize, seed: u64) {
1125        self.layers.push(ModelLayer::MultiHeadAttention(
1126            MultiHeadAttentionLayer::new(d_model, num_heads, seed),
1127        ));
1128        self.frozen.push(false);
1129    }
1130
1131    pub fn add_transformer_encoder(
1132        &mut self,
1133        d_model: usize,
1134        num_heads: usize,
1135        d_ff: usize,
1136        seed: u64,
1137    ) {
1138        self.layers.push(ModelLayer::TransformerEncoder(
1139            TransformerEncoderLayer::new(d_model, num_heads, d_ff, seed),
1140        ));
1141        self.frozen.push(false);
1142    }
1143
1144    pub fn add_feed_forward(&mut self, d_model: usize, d_ff: usize, seed: u64) {
1145        self.layers
1146            .push(ModelLayer::FeedForward(FeedForwardLayer::new(
1147                d_model, d_ff, seed,
1148            )));
1149        self.frozen.push(false);
1150    }
1151
1152    /// Push a pre-built layer directly (used by fusion for inference-only layers).
1153    pub fn push_raw_layer(&mut self, layer: ModelLayer) {
1154        self.layers.push(layer);
1155        self.frozen.push(false);
1156    }
1157
1158    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
1159        let mut current = input;
1160        for layer in &self.layers {
1161            current = layer.forward(graph, current)?;
1162        }
1163        Ok(current)
1164    }
1165
1166    /// Pure-tensor inference forward pass (no autograd graph).
1167    ///
1168    /// Supports Conv2d, BatchNorm2d, MaxPool2d, AvgPool2d, Flatten, Softmax,
1169    /// and simple activation layers (ReLU, Sigmoid, Tanh, LeakyReLU).
1170    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
1171        let mut current = input.clone();
1172        for layer in &self.layers {
1173            current = match layer {
1174                ModelLayer::Conv2d(l) => l.forward_inference(&current)?,
1175                ModelLayer::BatchNorm2d(l) => l.forward_inference(&current)?,
1176                ModelLayer::MaxPool2d(l) => l.forward_inference(&current)?,
1177                ModelLayer::AvgPool2d(l) => l.forward_inference(&current)?,
1178                ModelLayer::Flatten(l) => l.forward_inference(&current)?,
1179                ModelLayer::Softmax(l) => l.forward_inference(&current)?,
1180                ModelLayer::ReLU(_) => {
1181                    let out_data: Vec<f32> = current.data().iter().map(|&v| v.max(0.0)).collect();
1182                    Tensor::from_vec(current.shape().to_vec(), out_data)?
1183                }
1184                ModelLayer::Sigmoid(_) => {
1185                    let out_data: Vec<f32> = current
1186                        .data()
1187                        .iter()
1188                        .map(|&v| 1.0 / (1.0 + (-v).exp()))
1189                        .collect();
1190                    Tensor::from_vec(current.shape().to_vec(), out_data)?
1191                }
1192                ModelLayer::Tanh(_) => {
1193                    let out_data: Vec<f32> = current.data().iter().map(|&v| v.tanh()).collect();
1194                    Tensor::from_vec(current.shape().to_vec(), out_data)?
1195                }
1196                ModelLayer::LeakyReLU(l) => {
1197                    let slope = l.negative_slope();
1198                    let out_data: Vec<f32> = current
1199                        .data()
1200                        .iter()
1201                        .map(|&v| if v >= 0.0 { v } else { slope * v })
1202                        .collect();
1203                    Tensor::from_vec(current.shape().to_vec(), out_data)?
1204                }
1205                ModelLayer::GlobalAvgPool2d(l) => l.forward_inference(&current)?,
1206                ModelLayer::DepthwiseConv2d(l) => l.forward_inference(&current)?,
1207                ModelLayer::SeparableConv2d(l) => l.forward_inference(&current)?,
1208                ModelLayer::Dropout(d) => {
1209                    if self.training && d.rate() > 0.0 {
1210                        let scale = 1.0 / (1.0 - d.rate());
1211                        let scaled: Vec<f32> = current.data().iter().map(|&v| v * scale).collect();
1212                        Tensor::from_vec(current.shape().to_vec(), scaled)?
1213                    } else {
1214                        current
1215                    }
1216                }
1217                ModelLayer::Conv1d(l) => l.forward_inference(&current)?,
1218                ModelLayer::Conv3d(l) => l.forward_inference(&current)?,
1219                ModelLayer::ConvTranspose2d(l) => l.forward_inference(&current)?,
1220                ModelLayer::AdaptiveAvgPool2d(l) => l.forward_inference(&current)?,
1221                ModelLayer::AdaptiveMaxPool2d(l) => l.forward_inference(&current)?,
1222                ModelLayer::InstanceNorm(l) => l.forward_inference(&current)?,
1223                ModelLayer::PixelShuffle(l) => l.forward_inference(&current)?,
1224                ModelLayer::Upsample(l) => l.forward_inference(&current)?,
1225                ModelLayer::GELU(l) => l.forward_inference(&current)?,
1226                ModelLayer::SiLU(l) => l.forward_inference(&current)?,
1227                ModelLayer::Mish(l) => l.forward_inference(&current)?,
1228                ModelLayer::PReLU(l) => l.forward_inference(&current)?,
1229                ModelLayer::ResidualBlock(l) => l.forward_inference(&current)?,
1230                ModelLayer::Rnn(l) => l.forward_inference(&current)?,
1231                ModelLayer::Lstm(l) => l.forward_inference(&current)?,
1232                ModelLayer::Gru(l) => l.forward_inference(&current)?,
1233                ModelLayer::MultiHeadAttention(l) => l.forward_inference(&current)?,
1234                ModelLayer::TransformerEncoder(l) => l.forward_inference(&current)?,
1235                ModelLayer::FeedForward(l) => l.forward_inference(&current)?,
1236                ModelLayer::DeformableConv2d(l) => l.forward_inference(&current)?,
1237                ModelLayer::Linear(l) => l.forward_inference(&current)?,
1238                ModelLayer::Embedding(_)
1239                | ModelLayer::LayerNorm(_)
1240                | ModelLayer::GroupNorm(_)
1241                | ModelLayer::LoraLinear(_) => return Err(ModelError::GraphOnlyLayer),
1242            };
1243        }
1244        Ok(current)
1245    }
1246
    /// Registers CNN layer parameters (Conv2d weight/bias, BatchNorm2d gamma/beta, etc.)
    /// as graph variables for autograd training.
    ///
    /// Layers whose parameters are already registered (i.e. `weight_node().is_some()`)
    /// are skipped, so this method is safe to call multiple times.
    ///
    /// After registration, the persistent-node watermark is advanced so the
    /// newly added variable nodes survive graph truncation.
    pub fn register_cnn_params(&mut self, graph: &mut Graph) {
        for layer in &mut self.layers {
            // Each guard tests one representative node per layer: if that node
            // is already set, the whole layer was registered by a previous call.
            match layer {
                ModelLayer::Conv2d(l) if l.weight_node().is_none() => l.register_params(graph),
                ModelLayer::BatchNorm2d(l) if l.gamma_node().is_none() => l.register_params(graph),
                ModelLayer::DepthwiseConv2d(l) if l.weight_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::SeparableConv2d(l) if l.depthwise().weight_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::Conv1d(l) if l.weight_node().is_none() => l.register_params(graph),
                ModelLayer::Conv3d(l) if l.weight_node().is_none() => l.register_params(graph),
                ModelLayer::MultiHeadAttention(l) if l.w_q_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::FeedForward(l) if l.w1_node().is_none() => l.register_params(graph),
                ModelLayer::TransformerEncoder(l) if l.ln1_gamma_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::Rnn(l) if l.w_ih_node().is_none() => l.register_params(graph),
                ModelLayer::Lstm(l) if l.w_ih_node().is_none() => l.register_params(graph),
                ModelLayer::Gru(l) if l.w_ih_node().is_none() => l.register_params(graph),
                ModelLayer::DeformableConv2d(l) if l.weight_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::ConvTranspose2d(l) if l.weight_node().is_none() => {
                    l.register_params(graph)
                }
                ModelLayer::InstanceNorm(l) if l.gamma_node().is_none() => l.register_params(graph),
                ModelLayer::PReLU(l) if l.alpha_node().is_none() => l.register_params(graph),
                // All other layer kinds have no registrable parameters, or are
                // already registered (a failed guard falls through to here).
                _ => {}
            }
        }
        // Everything registered above must persist across graph resets.
        self.persistent_node_count = graph.node_count();
    }
1288
1289    /// Synchronizes CNN layer owned tensors from the graph (e.g. after optimizer step).
1290    pub fn sync_cnn_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
1291        for layer in &mut self.layers {
1292            match layer {
1293                ModelLayer::Conv2d(l) => l.sync_from_graph(graph)?,
1294                ModelLayer::BatchNorm2d(l) => l.sync_from_graph(graph)?,
1295                ModelLayer::DepthwiseConv2d(l) => l.sync_from_graph(graph)?,
1296                ModelLayer::SeparableConv2d(l) => l.sync_from_graph(graph)?,
1297                _ => {}
1298            }
1299        }
1300        Ok(())
1301    }
1302
1303    pub fn trainable_nodes(&self) -> Vec<NodeId> {
1304        let mut out = Vec::new();
1305        for layer in &self.layers {
1306            match layer {
1307                ModelLayer::Linear(linear) => {
1308                    out.extend(linear.trainable_nodes());
1309                }
1310                ModelLayer::Conv2d(conv) => {
1311                    if let Some(w) = conv.weight_node() {
1312                        out.push(w);
1313                    }
1314                    if let Some(b) = conv.bias_node() {
1315                        out.push(b);
1316                    }
1317                }
1318                ModelLayer::BatchNorm2d(bn) => {
1319                    if let Some(g) = bn.gamma_node() {
1320                        out.push(g);
1321                    }
1322                    if let Some(b) = bn.beta_node() {
1323                        out.push(b);
1324                    }
1325                }
1326                ModelLayer::DepthwiseConv2d(dw) => {
1327                    if let Some(w) = dw.weight_node() {
1328                        out.push(w);
1329                    }
1330                    if let Some(b) = dw.bias_node() {
1331                        out.push(b);
1332                    }
1333                }
1334                ModelLayer::SeparableConv2d(sep) => {
1335                    if let Some(w) = sep.depthwise().weight_node() {
1336                        out.push(w);
1337                    }
1338                    if let Some(b) = sep.depthwise().bias_node() {
1339                        out.push(b);
1340                    }
1341                    if let Some(w) = sep.pointwise().weight_node() {
1342                        out.push(w);
1343                    }
1344                    if let Some(b) = sep.pointwise().bias_node() {
1345                        out.push(b);
1346                    }
1347                }
1348                ModelLayer::LoraLinear(lora) => {
1349                    out.extend(lora.trainable_params());
1350                }
1351                _ => {}
1352            }
1353        }
1354        out
1355    }
1356
    /// Number of leading graph nodes that hold registered parameters and must
    /// survive graph truncation (updated by the `add_*` and registration methods).
    pub fn persistent_node_count(&self) -> usize {
        self.persistent_node_count
    }
1360
    /// Serializes every layer's configuration and parameter values into a
    /// `SequentialCheckpoint`.
    ///
    /// Graph-registered weights (Linear, Embedding, LayerNorm, GroupNorm,
    /// LoraLinear) are read from `graph`; layer-owned tensors (Conv2d,
    /// BatchNorm2d, pooling configs, ...) are snapshotted directly. LoRA
    /// layers are merged and stored as plain Linear checkpoints.
    ///
    /// # Errors
    ///
    /// Returns `ModelError::InferenceOnlyLayer` if the model contains a layer
    /// kind with no checkpoint representation, and propagates graph value
    /// lookup failures.
    ///
    /// # Panics
    ///
    /// Panics (via `expect`) if a Linear layer lacks registered weight/bias
    /// nodes — an invariant upheld by the `add_linear*` constructors.
    pub fn checkpoint(&self, graph: &Graph) -> Result<SequentialCheckpoint, ModelError> {
        let mut layers = Vec::with_capacity(self.layers.len());
        for layer in &self.layers {
            match layer {
                ModelLayer::Linear(linear) => {
                    // Clone out of the graph so the snapshot outlives the borrow.
                    let weight = graph
                        .value(linear.weight_node().expect("linear layer has weight node"))?
                        .clone();
                    let bias = graph
                        .value(linear.bias_node().expect("linear layer has bias node"))?
                        .clone();
                    layers.push(LayerCheckpoint::Linear {
                        in_features: linear.in_features(),
                        out_features: linear.out_features(),
                        weight: TensorSnapshot::from_tensor(&weight),
                        bias: TensorSnapshot::from_tensor(&bias),
                    });
                }
                // Parameter-less layers store only their kind / configuration.
                ModelLayer::ReLU(_) => layers.push(LayerCheckpoint::ReLU),
                ModelLayer::LeakyReLU(layer) => layers.push(LayerCheckpoint::LeakyReLU {
                    negative_slope: layer.negative_slope(),
                }),
                ModelLayer::Sigmoid(_) => layers.push(LayerCheckpoint::Sigmoid),
                ModelLayer::Tanh(_) => layers.push(LayerCheckpoint::Tanh),
                ModelLayer::Dropout(layer) => {
                    layers.push(LayerCheckpoint::Dropout { rate: layer.rate() })
                }
                ModelLayer::Conv2d(layer) => layers.push(LayerCheckpoint::Conv2d {
                    in_channels: layer.in_channels(),
                    out_channels: layer.out_channels(),
                    kernel_h: layer.kernel_h(),
                    kernel_w: layer.kernel_w(),
                    stride_h: layer.stride_h(),
                    stride_w: layer.stride_w(),
                    weight: TensorSnapshot::from_tensor(layer.weight()),
                    bias: layer.bias().map(TensorSnapshot::from_tensor),
                }),
                // Running statistics are saved too so eval-mode behavior is
                // restored exactly.
                ModelLayer::BatchNorm2d(layer) => layers.push(LayerCheckpoint::BatchNorm2d {
                    num_features: layer.num_features(),
                    epsilon: layer.epsilon(),
                    gamma: TensorSnapshot::from_tensor(layer.gamma()),
                    beta: TensorSnapshot::from_tensor(layer.beta()),
                    running_mean: TensorSnapshot::from_tensor(layer.running_mean()),
                    running_var: TensorSnapshot::from_tensor(layer.running_var()),
                }),
                ModelLayer::MaxPool2d(layer) => layers.push(LayerCheckpoint::MaxPool2d {
                    kernel_h: layer.kernel_h(),
                    kernel_w: layer.kernel_w(),
                    stride_h: layer.stride_h(),
                    stride_w: layer.stride_w(),
                }),
                ModelLayer::AvgPool2d(layer) => layers.push(LayerCheckpoint::AvgPool2d {
                    kernel_h: layer.kernel_h(),
                    kernel_w: layer.kernel_w(),
                    stride_h: layer.stride_h(),
                    stride_w: layer.stride_w(),
                }),
                ModelLayer::Flatten(_) => layers.push(LayerCheckpoint::Flatten),
                ModelLayer::GlobalAvgPool2d(_) => layers.push(LayerCheckpoint::GlobalAvgPool2d),
                ModelLayer::Softmax(_) => layers.push(LayerCheckpoint::Softmax),
                ModelLayer::Embedding(layer) => {
                    let w = graph.value(layer.weight_node())?;
                    layers.push(LayerCheckpoint::Embedding {
                        num_embeddings: layer.num_embeddings(),
                        embedding_dim: layer.embedding_dim(),
                        weight: TensorSnapshot::from_tensor(w),
                    });
                }
                ModelLayer::LayerNorm(layer) => {
                    let g = graph.value(layer.gamma_node())?;
                    let b = graph.value(layer.beta_node())?;
                    layers.push(LayerCheckpoint::LayerNorm {
                        normalized_shape: layer.normalized_shape(),
                        // NOTE(review): eps is hardcoded rather than read from
                        // the layer — confirm LayerNorm cannot be constructed
                        // with a different epsilon, else restore loses it.
                        eps: 1e-5,
                        gamma: TensorSnapshot::from_tensor(g),
                        beta: TensorSnapshot::from_tensor(b),
                    });
                }
                ModelLayer::GroupNorm(layer) => {
                    let g = graph.value(layer.gamma_node())?;
                    let b = graph.value(layer.beta_node())?;
                    layers.push(LayerCheckpoint::GroupNorm {
                        num_groups: layer.num_groups(),
                        num_channels: layer.num_channels(),
                        // NOTE(review): same hardcoded-eps concern as LayerNorm.
                        eps: 1e-5,
                        gamma: TensorSnapshot::from_tensor(g),
                        beta: TensorSnapshot::from_tensor(b),
                    });
                }
                ModelLayer::DepthwiseConv2d(layer) => {
                    layers.push(LayerCheckpoint::DepthwiseConv2d {
                        channels: layer.channels(),
                        kernel_h: layer.kernel_h(),
                        kernel_w: layer.kernel_w(),
                        stride_h: layer.stride_h(),
                        stride_w: layer.stride_w(),
                        weight: TensorSnapshot::from_tensor(layer.weight()),
                        bias: layer.bias().map(TensorSnapshot::from_tensor),
                    });
                }
                ModelLayer::SeparableConv2d(layer) => {
                    layers.push(LayerCheckpoint::SeparableConv2d {
                        in_channels: layer.in_channels(),
                        out_channels: layer.out_channels(),
                        kernel_h: layer.kernel_h(),
                        kernel_w: layer.kernel_w(),
                        stride_h: layer.stride_h(),
                        stride_w: layer.stride_w(),
                        depthwise_weight: TensorSnapshot::from_tensor(layer.depthwise().weight()),
                        pointwise_weight: TensorSnapshot::from_tensor(layer.pointwise().weight()),
                        // NOTE(review): only the pointwise bias is saved; the
                        // depthwise bias (if any) is dropped — verify.
                        bias: layer.pointwise().bias().map(TensorSnapshot::from_tensor),
                    });
                }
                ModelLayer::LoraLinear(lora) => {
                    // Merge LoRA weights and checkpoint as a regular Linear layer.
                    let merged_weight = lora.merge(graph)?;
                    // A missing bias is materialized as zeros so the restored
                    // Linear layer has the same effective output.
                    let bias = if let Some(bias_node) = lora.bias {
                        graph.value(bias_node)?.clone()
                    } else {
                        Tensor::zeros(vec![lora.out_features])?
                    };
                    layers.push(LayerCheckpoint::Linear {
                        in_features: lora.in_features,
                        out_features: lora.out_features,
                        weight: TensorSnapshot::from_tensor(&merged_weight),
                        bias: TensorSnapshot::from_tensor(&bias),
                    });
                }
                // Inference-only layers have no graph-registered weights to checkpoint.
                ModelLayer::Conv1d(_)
                | ModelLayer::Conv3d(_)
                | ModelLayer::ConvTranspose2d(_)
                | ModelLayer::AdaptiveAvgPool2d(_)
                | ModelLayer::AdaptiveMaxPool2d(_)
                | ModelLayer::InstanceNorm(_)
                | ModelLayer::PixelShuffle(_)
                | ModelLayer::Upsample(_)
                | ModelLayer::GELU(_)
                | ModelLayer::SiLU(_)
                | ModelLayer::Mish(_)
                | ModelLayer::PReLU(_)
                | ModelLayer::ResidualBlock(_)
                | ModelLayer::Rnn(_)
                | ModelLayer::Lstm(_)
                | ModelLayer::Gru(_)
                | ModelLayer::MultiHeadAttention(_)
                | ModelLayer::TransformerEncoder(_)
                | ModelLayer::FeedForward(_)
                | ModelLayer::DeformableConv2d(_) => {
                    return Err(ModelError::InferenceOnlyLayer);
                }
            }
        }
        Ok(SequentialCheckpoint { layers })
    }
1516
1517    pub fn from_checkpoint(
1518        graph: &mut Graph,
1519        checkpoint: &SequentialCheckpoint,
1520    ) -> Result<Self, ModelError> {
1521        let mut model = Self::new(graph);
1522        for layer in &checkpoint.layers {
1523            match layer {
1524                LayerCheckpoint::Linear {
1525                    in_features,
1526                    out_features,
1527                    weight,
1528                    bias,
1529                } => {
1530                    model.add_linear(
1531                        graph,
1532                        *in_features,
1533                        *out_features,
1534                        weight.clone().into_tensor()?,
1535                        bias.clone().into_tensor()?,
1536                    )?;
1537                }
1538                LayerCheckpoint::ReLU => model.add_relu(),
1539                LayerCheckpoint::LeakyReLU { negative_slope } => {
1540                    model.add_leaky_relu(*negative_slope)?
1541                }
1542                LayerCheckpoint::Sigmoid => model.add_sigmoid(),
1543                LayerCheckpoint::Tanh => model.add_tanh(),
1544                LayerCheckpoint::Dropout { rate } => model.add_dropout(*rate)?,
1545                LayerCheckpoint::Conv2d {
1546                    in_channels,
1547                    out_channels,
1548                    kernel_h,
1549                    kernel_w,
1550                    stride_h,
1551                    stride_w,
1552                    weight,
1553                    bias,
1554                } => {
1555                    let bias_tensor = match bias {
1556                        Some(b) => Some(b.clone().into_tensor()?),
1557                        None => None,
1558                    };
1559                    model.add_conv2d(
1560                        *in_channels,
1561                        *out_channels,
1562                        *kernel_h,
1563                        *kernel_w,
1564                        *stride_h,
1565                        *stride_w,
1566                        weight.clone().into_tensor()?,
1567                        bias_tensor,
1568                    )?;
1569                }
1570                LayerCheckpoint::BatchNorm2d {
1571                    num_features,
1572                    epsilon,
1573                    gamma,
1574                    beta,
1575                    running_mean,
1576                    running_var,
1577                } => {
1578                    model.add_batch_norm2d(
1579                        *num_features,
1580                        *epsilon,
1581                        gamma.clone().into_tensor()?,
1582                        beta.clone().into_tensor()?,
1583                        running_mean.clone().into_tensor()?,
1584                        running_var.clone().into_tensor()?,
1585                    )?;
1586                }
1587                LayerCheckpoint::MaxPool2d {
1588                    kernel_h,
1589                    kernel_w,
1590                    stride_h,
1591                    stride_w,
1592                } => model.add_max_pool2d(*kernel_h, *kernel_w, *stride_h, *stride_w)?,
1593                LayerCheckpoint::AvgPool2d {
1594                    kernel_h,
1595                    kernel_w,
1596                    stride_h,
1597                    stride_w,
1598                } => model.add_avg_pool2d(*kernel_h, *kernel_w, *stride_h, *stride_w)?,
1599                LayerCheckpoint::Flatten => model.add_flatten(),
1600                LayerCheckpoint::GlobalAvgPool2d => model.add_global_avg_pool2d(),
1601                LayerCheckpoint::Softmax => model.add_softmax(),
1602                LayerCheckpoint::Embedding {
1603                    num_embeddings,
1604                    embedding_dim,
1605                    weight,
1606                } => {
1607                    model.add_embedding(
1608                        graph,
1609                        *num_embeddings,
1610                        *embedding_dim,
1611                        weight.clone().into_tensor()?,
1612                    )?;
1613                }
1614                LayerCheckpoint::LayerNorm {
1615                    normalized_shape,
1616                    eps,
1617                    gamma: _,
1618                    beta: _,
1619                } => {
1620                    model.add_layer_norm(graph, *normalized_shape, *eps)?;
1621                }
1622                LayerCheckpoint::GroupNorm {
1623                    num_groups,
1624                    num_channels,
1625                    eps,
1626                    gamma: _,
1627                    beta: _,
1628                } => {
1629                    model.add_group_norm(graph, *num_groups, *num_channels, *eps)?;
1630                }
1631                LayerCheckpoint::DepthwiseConv2d {
1632                    channels,
1633                    kernel_h,
1634                    kernel_w,
1635                    stride_h,
1636                    stride_w,
1637                    weight,
1638                    bias,
1639                } => {
1640                    let bias_tensor = match bias {
1641                        Some(b) => Some(b.clone().into_tensor()?),
1642                        None => None,
1643                    };
1644                    model.add_depthwise_conv2d(
1645                        *channels,
1646                        *kernel_h,
1647                        *kernel_w,
1648                        *stride_h,
1649                        *stride_w,
1650                        weight.clone().into_tensor()?,
1651                        bias_tensor,
1652                    )?;
1653                }
1654                LayerCheckpoint::SeparableConv2d {
1655                    in_channels,
1656                    out_channels,
1657                    kernel_h,
1658                    kernel_w,
1659                    stride_h,
1660                    stride_w,
1661                    depthwise_weight,
1662                    pointwise_weight,
1663                    bias,
1664                } => {
1665                    let bias_tensor = match bias {
1666                        Some(b) => Some(b.clone().into_tensor()?),
1667                        None => None,
1668                    };
1669                    model.add_separable_conv2d(
1670                        *in_channels,
1671                        *out_channels,
1672                        *kernel_h,
1673                        *kernel_w,
1674                        *stride_h,
1675                        *stride_w,
1676                        depthwise_weight.clone().into_tensor()?,
1677                        pointwise_weight.clone().into_tensor()?,
1678                        bias_tensor,
1679                    )?;
1680                }
1681            }
1682        }
1683        Ok(model)
1684    }
1685}
1686
/// Renders a parameter count in compact human-readable form: millions as
/// "X.XM", thousands as "X.XK", and anything smaller verbatim.
fn format_param_count(n: usize) -> String {
    let (value, suffix) = if n >= 1_000_000 {
        (n as f64 / 1_000_000.0, "M")
    } else if n >= 1_000 {
        (n as f64 / 1_000.0, "K")
    } else {
        return n.to_string();
    };
    format!("{value:.1}{suffix}")
}