use super::{ActivationFunction, KerasLayer};
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{ArrayD, IxDyn};

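/// 2D convolution layer over NHWC inputs, analogous to Keras' `Conv2D`.
///
/// The kernel has shape `[kernel_h, kernel_w, in_channels, filters]` and is
/// initialized lazily in `build` once the input channel count is known; the
/// forward pass is a naive direct convolution.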
pub struct Conv2D {
    filters: usize,
    kernel_size: (usize, usize),
    strides: (usize, usize),
    padding: String,
    activation: Option<ActivationFunction>,
    use_bias: bool,
    kernel: Option<ArrayD<f64>>,
    bias: Option<ArrayD<f64>>,
    built: bool,
    layer_name: Option<String>,
}

impl Conv2D {
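    /// Creates a `Conv2D` layer with the given number of filters and kernel size.
    /// Strides default to `(1, 1)`, padding to `"valid"`, and bias is enabled.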
    pub fn new(filters: usize, kernel_size: (usize, usize)) -> Self {
        Self {
            filters,
            kernel_size,
            strides: (1, 1),
            padding: "valid".to_string(),
            activation: None,
            use_bias: true,
            kernel: None,
            bias: None,
            built: false,
            layer_name: None,
        }
    }

    pub fn strides(mut self, strides: (usize, usize)) -> Self {
        self.strides = strides;
        self
    }

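    /// Sets the padding mode: `"valid"` (no padding) or `"same"`
    /// (zero-pad by `kernel_size / 2` on each side).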
    pub fn padding(mut self, padding: &str) -> Self {
        self.padding = padding.to_string();
        self
    }

    pub fn activation(mut self, activation: ActivationFunction) -> Self {
        self.activation = Some(activation);
        self
    }

    pub fn use_bias(mut self, use_bias: bool) -> Self {
        self.use_bias = use_bias;
        self
    }

    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
}

impl KerasLayer for Conv2D {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::ModelNotTrained(
                "Layer not built. Call build() first.".to_string(),
            ));
        }

        let kernel = self
            .kernel
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("Conv2D kernel not initialized".to_string()))?;

        let shape = input.shape();
        let (batch, height, width, _in_channels) = (shape[0], shape[1], shape[2], shape[3]);

        let (pad_h, pad_w) = if self.padding == "same" {
            (self.kernel_size.0 / 2, self.kernel_size.1 / 2)
        } else {
            (0, 0)
        };

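        // Output spatial size: (in + 2 * pad - kernel) / stride + 1. With "same"
        // padding, stride 1, and an odd kernel this preserves the input size.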
        let out_h = (height + 2 * pad_h - self.kernel_size.0) / self.strides.0 + 1;
        let out_w = (width + 2 * pad_w - self.kernel_size.1) / self.strides.1 + 1;

        let mut output = ArrayD::zeros(IxDyn(&[batch, out_h, out_w, self.filters]));

        for b in 0..batch {
            for oh in 0..out_h {
                for ow in 0..out_w {
                    for f in 0..self.filters {
                        let mut sum = if self.use_bias {
                            self.bias.as_ref().map_or(0.0, |bias| bias[[f]])
                        } else {
                            0.0
                        };

                        for kh in 0..self.kernel_size.0 {
                            for kw in 0..self.kernel_size.1 {
                                // Shift by the padding offset; taps that fall outside
                                // the input act as zero padding and are skipped.
                                let ih = (oh * self.strides.0 + kh) as isize - pad_h as isize;
                                let iw = (ow * self.strides.1 + kw) as isize - pad_w as isize;
                                if ih >= 0 && iw >= 0 && (ih as usize) < height && (iw as usize) < width
                                {
                                    let (ih, iw) = (ih as usize, iw as usize);
                                    for ic in 0..shape[3] {
                                        sum += input[[b, ih, iw, ic]] * kernel[[kh, kw, ic, f]];
                                    }
                                }
                            }
                        }
                        output[[b, oh, ow, f]] = sum;
                    }
                }
            }
        }

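        // Apply the activation elementwise. Softmax is left as a pass-through here
        // because it cannot be evaluated per element.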
        if let Some(ref activation) = self.activation {
            output = output.mapv(|x| match activation {
                ActivationFunction::ReLU => x.max(0.0),
                ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
                ActivationFunction::Tanh => x.tanh(),
                ActivationFunction::Softmax => x,
                ActivationFunction::LeakyReLU(alpha) => {
                    if x > 0.0 {
                        x
                    } else {
                        alpha * x
                    }
                }
                ActivationFunction::ELU(alpha) => {
                    if x > 0.0 {
                        x
                    } else {
                        alpha * (x.exp() - 1.0)
                    }
                }
                ActivationFunction::Linear => x,
            });
        }

        Ok(output)
    }

    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        let in_channels = *input_shape
            .last()
            .ok_or_else(|| MLError::InvalidConfiguration("Invalid input shape".to_string()))?;

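        // Initialize weights uniformly in [-scale, scale) with
        // scale = sqrt(2 / fan_in), where fan_in = kernel_h * kernel_w * in_channels.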
        let scale = (2.0 / ((self.kernel_size.0 * self.kernel_size.1 * in_channels) as f64)).sqrt();
        let kernel = ArrayD::from_shape_fn(
            IxDyn(&[
                self.kernel_size.0,
                self.kernel_size.1,
                in_channels,
                self.filters,
            ]),
            |_| fastrand::f64() * 2.0 * scale - scale,
        );

        self.kernel = Some(kernel);

        if self.use_bias {
            self.bias = Some(ArrayD::zeros(IxDyn(&[self.filters])));
        }

        self.built = true;
        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        let (height, width) = (input_shape[1], input_shape[2]);
        let (pad_h, pad_w) = if self.padding == "same" {
            (self.kernel_size.0 / 2, self.kernel_size.1 / 2)
        } else {
            (0, 0)
        };
        let out_h = (height + 2 * pad_h - self.kernel_size.0) / self.strides.0 + 1;
        let out_w = (width + 2 * pad_w - self.kernel_size.1) / self.strides.1 + 1;
        vec![input_shape[0], out_h, out_w, self.filters]
    }

    fn count_params(&self) -> usize {
        let kernel_params = self.kernel.as_ref().map_or(0, |k| k.len());
        let bias_params = self.bias.as_ref().map_or(0, |b| b.len());
        kernel_params + bias_params
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        let mut weights = vec![];
        if let Some(ref k) = self.kernel {
            weights.push(k.clone());
        }
        if let Some(ref b) = self.bias {
            weights.push(b.clone());
        }
        weights
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        if !weights.is_empty() {
            self.kernel = Some(weights[0].clone());
        }
        if weights.len() > 1 {
            self.bias = Some(weights[1].clone());
        }
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name.as_deref().unwrap_or("conv2d")
    }
}

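/// 2D max-pooling layer over NHWC inputs.
///
/// Note: the `padding` mode is stored but the forward pass currently pools
/// over valid windows only.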
pub struct MaxPooling2D {
    pool_size: (usize, usize),
    strides: (usize, usize),
    padding: String,
    built: bool,
    layer_name: Option<String>,
}

impl MaxPooling2D {
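    /// Creates a max-pooling layer; strides default to the pool size.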
    pub fn new(pool_size: (usize, usize)) -> Self {
        Self {
            pool_size,
            strides: pool_size,
            padding: "valid".to_string(),
            built: false,
            layer_name: None,
        }
    }

    pub fn strides(mut self, strides: (usize, usize)) -> Self {
        self.strides = strides;
        self
    }

    pub fn padding(mut self, padding: &str) -> Self {
        self.padding = padding.to_string();
        self
    }

    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
}

impl KerasLayer for MaxPooling2D {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        let shape = input.shape();
        let (batch, height, width, channels) = (shape[0], shape[1], shape[2], shape[3]);

        let out_h = (height - self.pool_size.0) / self.strides.0 + 1;
        let out_w = (width - self.pool_size.1) / self.strides.1 + 1;

        let mut output = ArrayD::zeros(IxDyn(&[batch, out_h, out_w, channels]));

        for b in 0..batch {
            for oh in 0..out_h {
                for ow in 0..out_w {
                    for c in 0..channels {
                        let mut max_val = f64::NEG_INFINITY;
                        for ph in 0..self.pool_size.0 {
                            for pw in 0..self.pool_size.1 {
                                let ih = oh * self.strides.0 + ph;
                                let iw = ow * self.strides.1 + pw;
                                if ih < height && iw < width {
                                    max_val = max_val.max(input[[b, ih, iw, c]]);
                                }
                            }
                        }
                        output[[b, oh, ow, c]] = max_val;
                    }
                }
            }
        }

        Ok(output)
    }

    fn build(&mut self, _input_shape: &[usize]) -> Result<()> {
        self.built = true;
        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        let out_h = (input_shape[1] - self.pool_size.0) / self.strides.0 + 1;
        let out_w = (input_shape[2] - self.pool_size.1) / self.strides.1 + 1;
        vec![input_shape[0], out_h, out_w, input_shape[3]]
    }

    fn count_params(&self) -> usize {
        0
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        vec![]
    }

    fn set_weights(&mut self, _weights: Vec<ArrayD<f64>>) -> Result<()> {
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name.as_deref().unwrap_or("max_pooling2d")
    }
}

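/// Global average pooling over the spatial dimensions of NHWC inputs,
/// reducing `[batch, height, width, channels]` to `[batch, channels]`.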
pub struct GlobalAveragePooling2D {
    built: bool,
    layer_name: Option<String>,
}

impl GlobalAveragePooling2D {
    pub fn new() -> Self {
        Self {
            built: false,
            layer_name: None,
        }
    }

    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
}

impl Default for GlobalAveragePooling2D {
    fn default() -> Self {
        Self::new()
    }
}

impl KerasLayer for GlobalAveragePooling2D {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        let shape = input.shape();
        let (batch, height, width, channels) = (shape[0], shape[1], shape[2], shape[3]);

        let mut output = ArrayD::zeros(IxDyn(&[batch, channels]));
        let count = (height * width) as f64;

        for b in 0..batch {
            for c in 0..channels {
                let mut sum = 0.0;
                for h in 0..height {
                    for w in 0..width {
                        sum += input[[b, h, w, c]];
                    }
                }
                output[[b, c]] = sum / count;
            }
        }

        Ok(output)
    }

    fn build(&mut self, _input_shape: &[usize]) -> Result<()> {
        self.built = true;
        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        vec![input_shape[0], input_shape[3]]
    }

    fn count_params(&self) -> usize {
        0
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        vec![]
    }

    fn set_weights(&mut self, _weights: Vec<ArrayD<f64>>) -> Result<()> {
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name
            .as_deref()
            .unwrap_or("global_average_pooling2d")
    }
}