1use serde::{Deserialize, Serialize};
2use yscv_tensor::Tensor;
3
4use crate::ModelError;
5
/// Serializable, owned copy of a tensor's contents.
///
/// Holds the two pieces that [`TensorSnapshot::into_tensor`] hands to
/// `Tensor::from_vec`: a shape vector and a flat `f32` data vector.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TensorSnapshot {
    /// Shape as reported by `Tensor::shape()` when the snapshot was taken.
    pub shape: Vec<usize>,
    /// Element values as reported by `Tensor::data()` when the snapshot was taken.
    // NOTE(review): element ordering / memory layout is whatever `Tensor::data()`
    // exposes — confirm against the tensor crate.
    pub data: Vec<f32>,
}
12
13impl TensorSnapshot {
14 pub fn from_tensor(tensor: &Tensor) -> Self {
15 Self {
16 shape: tensor.shape().to_vec(),
17 data: tensor.data().to_vec(),
18 }
19 }
20
21 pub fn into_tensor(self) -> Result<Tensor, ModelError> {
22 Tensor::from_vec(self.shape, self.data).map_err(Into::into)
23 }
24}
25
/// Serializable description of a single layer, one variant per layer kind.
///
/// The serde representation is adjacently tagged: the variant name is stored
/// under the `"layer"` key and the variant's fields (if any) under the
/// `"payload"` key. Variant and field names are therefore part of the
/// on-disk format.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "layer", content = "payload")]
pub enum LayerCheckpoint {
    /// Linear layer: feature counts plus weight and bias snapshots.
    Linear {
        in_features: usize,
        out_features: usize,
        weight: TensorSnapshot,
        bias: TensorSnapshot,
    },
    /// ReLU activation; carries no parameters.
    ReLU,
    /// LeakyReLU activation with its negative slope.
    LeakyReLU {
        negative_slope: f32,
    },
    /// Sigmoid activation; carries no parameters.
    Sigmoid,
    /// Tanh activation; carries no parameters.
    Tanh,
    /// Dropout layer with its rate.
    Dropout {
        rate: f32,
    },
    /// 2-D convolution: channel counts, kernel size, stride, weight and an
    /// optional bias.
    Conv2d {
        in_channels: usize,
        out_channels: usize,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
        weight: TensorSnapshot,
        bias: Option<TensorSnapshot>,
    },
    /// 2-D batch normalization: feature count, epsilon, gamma/beta and the
    /// running mean/variance snapshots.
    BatchNorm2d {
        num_features: usize,
        epsilon: f32,
        gamma: TensorSnapshot,
        beta: TensorSnapshot,
        running_mean: TensorSnapshot,
        running_var: TensorSnapshot,
    },
    /// 2-D max pooling: kernel size and stride.
    MaxPool2d {
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    },
    /// 2-D average pooling: kernel size and stride.
    AvgPool2d {
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
    },
    /// Flatten layer; carries no parameters.
    Flatten,
    /// Global 2-D average pooling; carries no parameters.
    GlobalAvgPool2d,
    /// Softmax layer; carries no parameters.
    Softmax,
    /// Embedding table: entry count, embedding dimension and weight snapshot.
    Embedding {
        num_embeddings: usize,
        embedding_dim: usize,
        weight: TensorSnapshot,
    },
    /// Layer normalization: normalized size, epsilon and gamma/beta snapshots.
    LayerNorm {
        normalized_shape: usize,
        eps: f32,
        gamma: TensorSnapshot,
        beta: TensorSnapshot,
    },
    /// Group normalization: group and channel counts, epsilon and gamma/beta
    /// snapshots.
    GroupNorm {
        num_groups: usize,
        num_channels: usize,
        eps: f32,
        gamma: TensorSnapshot,
        beta: TensorSnapshot,
    },
    /// Depthwise 2-D convolution: channel count, kernel size, stride, weight
    /// and an optional bias.
    DepthwiseConv2d {
        channels: usize,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
        weight: TensorSnapshot,
        bias: Option<TensorSnapshot>,
    },
    /// Separable 2-D convolution: channel counts, kernel size, stride,
    /// separate depthwise and pointwise weights, and an optional bias.
    SeparableConv2d {
        in_channels: usize,
        out_channels: usize,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
        depthwise_weight: TensorSnapshot,
        pointwise_weight: TensorSnapshot,
        bias: Option<TensorSnapshot>,
    },
}
117
/// Serializable checkpoint made of a list of [`LayerCheckpoint`] entries.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SequentialCheckpoint {
    /// The stored layer checkpoints, in vector order.
    // NOTE(review): presumably this mirrors the model's layer order — confirm
    // against the code that builds the checkpoint.
    pub layers: Vec<LayerCheckpoint>,
}
123
124pub fn checkpoint_to_json(checkpoint: &SequentialCheckpoint) -> Result<String, ModelError> {
125 serde_json::to_string_pretty(checkpoint).map_err(|err| ModelError::CheckpointSerialization {
126 message: err.to_string(),
127 })
128}
129
130pub fn checkpoint_from_json(json: &str) -> Result<SequentialCheckpoint, ModelError> {
131 serde_json::from_str(json).map_err(|err| ModelError::CheckpointSerialization {
132 message: err.to_string(),
133 })
134}