// optirs_learned/transformer/architecture/feedforward.rs
use std::fmt::Debug;
2#[allow(dead_code)]
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::numeric::Float;
11use scirs2_core::random::{Random, Rng as SCRRng};
12
13use super::super::TransformerOptimizerConfig;
14use crate::error::{OptimError, Result};
15
/// Element-wise transformation applied to the output of
/// [`OutputProjectionLayer::forward`] after the affine projection.
#[derive(Debug, Clone, Copy)]
pub enum OutputTransformation {
    /// Identity — the affine projection is returned unchanged.
    Linear,
    /// Hyperbolic tangent applied element-wise.
    Tanh,
    /// Logistic sigmoid 1 / (1 + e^(-x)) applied element-wise.
    Sigmoid,
    /// Fixed 1.1x gain (placeholder for a learned activation — see `forward`).
    LearnedActivation,
    /// Per-output-column scaling by 1 + 0.1 * sin(column index).
    ParameterScaling,
}
30
/// Final linear projection (`input @ weights + bias`) followed by an
/// optional element-wise [`OutputTransformation`].
#[derive(Debug, Clone)]
pub struct OutputProjectionLayer<T: Float + Debug + Send + Sync + 'static> {
    // Projection matrix, shape (input_dim, output_dim).
    weights: Array2<T>,

    // Additive bias, length output_dim.
    bias: Array1<T>,

    // Transformation applied after the affine projection.
    transformation: OutputTransformation,
}
43
/// Linear embedding that maps raw input vectors of size `input_dim`
/// into the model dimension via a learned weight matrix.
#[derive(Debug, Clone)]
pub struct InputEmbedding<T: Float + Debug + Send + Sync + 'static> {
    // Embedding matrix, shape (input_dim, modeldim).
    weights: Array2<T>,

    // Expected dimensionality of each input row.
    input_dim: usize,

    // Output (model) dimensionality of the embedding.
    modeldim: usize,
}
56
57impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> OutputProjectionLayer<T> {
58 pub fn new(input_dim: usize, output_dim: usize) -> Result<Self> {
60 let mut rng = scirs2_core::random::thread_rng();
61 let mut weights = Array2::zeros((input_dim, output_dim));
62
63 let bound = (6.0 / (input_dim + output_dim) as f64).sqrt();
65 for elem in weights.iter_mut() {
66 *elem = T::from((rng.random::<f64>() - 0.5) * 2.0 * bound).expect("unwrap failed");
67 }
68
69 let bias = Array1::zeros(output_dim);
70
71 Ok(Self {
72 weights,
73 bias,
74 transformation: OutputTransformation::Linear,
75 })
76 }
77
78 pub fn new_with_transformation(
80 input_dim: usize,
81 output_dim: usize,
82 transformation: OutputTransformation,
83 ) -> Result<Self> {
84 let mut layer = Self::new(input_dim, output_dim)?;
85 layer.transformation = transformation;
86 Ok(layer)
87 }
88
89 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
91 let (seq_len, input_dim) = input.dim();
92 let (weight_in, weight_out) = self.weights.dim();
93
94 if input_dim != weight_in {
95 return Err(OptimError::InvalidConfig(
96 "Input dimension doesn't match weight matrix".to_string(),
97 ));
98 }
99
100 let mut output = Array2::zeros((seq_len, weight_out));
101
102 for i in 0..seq_len {
104 for j in 0..weight_out {
105 let mut sum = T::zero();
106 for k in 0..input_dim {
107 sum = sum + input[[i, k]] * self.weights[[k, j]];
108 }
109 output[[i, j]] = sum + self.bias[j];
110 }
111 }
112
113 match self.transformation {
115 OutputTransformation::Linear => {
116 }
118 OutputTransformation::Tanh => {
119 output.mapv_inplace(|x| x.tanh());
120 }
121 OutputTransformation::Sigmoid => {
122 output.mapv_inplace(|x| T::one() / (T::one() + (-x).exp()));
123 }
124 OutputTransformation::LearnedActivation => {
125 output.mapv_inplace(|x| {
127 x * scirs2_core::numeric::NumCast::from(1.1).unwrap_or_else(|| T::zero())
128 });
129 }
130 OutputTransformation::ParameterScaling => {
131 for j in 0..weight_out {
133 let scale = T::from(1.0 + 0.1 * (j as f64).sin()).expect("unwrap failed");
134 for i in 0..seq_len {
135 output[[i, j]] = output[[i, j]] * scale;
136 }
137 }
138 }
139 }
140
141 Ok(output)
142 }
143
144 pub fn transformation(&self) -> OutputTransformation {
146 self.transformation
147 }
148
149 pub fn set_transformation(&mut self, transformation: OutputTransformation) {
151 self.transformation = transformation;
152 }
153
154 pub fn weights(&self) -> &Array2<T> {
156 &self.weights
157 }
158
159 pub fn bias(&self) -> &Array1<T> {
161 &self.bias
162 }
163
164 pub fn update_parameters(&mut self, weights: Array2<T>, bias: Array1<T>) -> Result<()> {
166 let (weight_in, weight_out) = weights.dim();
167 let bias_dim = bias.len();
168
169 if weight_out != bias_dim {
170 return Err(OptimError::InvalidConfig(
171 "Weight output dimension doesn't match bias dimension".to_string(),
172 ));
173 }
174
175 if (weight_in, weight_out) == self.weights.dim() && bias_dim == self.bias.len() {
177 self.weights = weights;
178 self.bias = bias;
179 Ok(())
180 } else {
181 Err(OptimError::InvalidConfig(
182 "New parameter dimensions don't match current layer dimensions".to_string(),
183 ))
184 }
185 }
186}
187
188impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> InputEmbedding<T> {
189 pub fn new(input_dim: usize, model_dim: usize) -> Result<Self> {
191 let mut rng = scirs2_core::random::thread_rng();
192 let mut weights = Array2::zeros((input_dim, model_dim));
193
194 let bound = (6.0 / (input_dim + model_dim) as f64).sqrt();
196 for elem in weights.iter_mut() {
197 *elem = T::from((rng.random::<f64>() - 0.5) * 2.0 * bound).expect("unwrap failed");
198 }
199
200 Ok(Self {
201 weights,
202 input_dim,
203 modeldim: model_dim,
204 })
205 }
206
207 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
209 let (seq_len, input_dim) = input.dim();
210
211 if input_dim != self.input_dim {
212 return Err(OptimError::InvalidConfig(format!(
213 "Input dimension {} doesn't match embedding input dimension {}",
214 input_dim, self.input_dim
215 )));
216 }
217
218 let mut output = Array2::zeros((seq_len, self.modeldim));
219
220 for i in 0..seq_len {
222 for j in 0..self.modeldim {
223 let mut sum = T::zero();
224 for k in 0..self.input_dim {
225 sum = sum + input[[i, k]] * self.weights[[k, j]];
226 }
227 output[[i, j]] = sum;
228 }
229 }
230
231 Ok(output)
232 }
233
234 pub fn input_dim(&self) -> usize {
236 self.input_dim
237 }
238
239 pub fn model_dim(&self) -> usize {
241 self.modeldim
242 }
243
244 pub fn weights(&self) -> &Array2<T> {
246 &self.weights
247 }
248
249 pub fn update_weights(&mut self, weights: Array2<T>) -> Result<()> {
251 let (weight_in, weight_out) = weights.dim();
252
253 if weight_in != self.input_dim || weight_out != self.modeldim {
254 return Err(OptimError::InvalidConfig(
255 "New weight dimensions don't match embedding dimensions".to_string(),
256 ));
257 }
258
259 self.weights = weights;
260 Ok(())
261 }
262}