quantrs2_ml/keras_api/
conv.rs

1//! Convolutional layers for Keras-like API
2
3use super::{ActivationFunction, KerasLayer};
4use crate::error::{MLError, Result};
5use scirs2_core::ndarray::{ArrayD, IxDyn};
6
7/// Conv2D layer (Keras-compatible)
/// Conv2D layer (Keras-compatible)
///
/// 2D convolution over NHWC input `[batch, height, width, in_channels]`.
/// Configure via the builder methods, then call `build()` to initialize
/// weights before the first forward pass.
pub struct Conv2D {
    /// Number of output filters (output channel count)
    filters: usize,
    /// Spatial kernel size as (height, width)
    kernel_size: (usize, usize),
    /// Kernel step along (height, width); defaults to (1, 1)
    strides: (usize, usize),
    /// Padding mode: "valid" (no padding) or "same"
    padding: String,
    /// Optional element-wise activation applied after the convolution
    activation: Option<ActivationFunction>,
    /// Whether a learnable per-filter bias is added
    use_bias: bool,
    /// Kernel weights, shape [kh, kw, in_channels, filters]; None until build()
    kernel: Option<ArrayD<f64>>,
    /// Per-filter bias, shape [filters]; None until build() or when use_bias is false
    bias: Option<ArrayD<f64>>,
    /// True once build() has initialized the parameters
    built: bool,
    /// Optional user-supplied layer name
    layer_name: Option<String>,
}
30
31impl Conv2D {
32    /// Create new Conv2D layer
33    pub fn new(filters: usize, kernel_size: (usize, usize)) -> Self {
34        Self {
35            filters,
36            kernel_size,
37            strides: (1, 1),
38            padding: "valid".to_string(),
39            activation: None,
40            use_bias: true,
41            kernel: None,
42            bias: None,
43            built: false,
44            layer_name: None,
45        }
46    }
47
48    /// Set strides
49    pub fn strides(mut self, strides: (usize, usize)) -> Self {
50        self.strides = strides;
51        self
52    }
53
54    /// Set padding
55    pub fn padding(mut self, padding: &str) -> Self {
56        self.padding = padding.to_string();
57        self
58    }
59
60    /// Set activation
61    pub fn activation(mut self, activation: ActivationFunction) -> Self {
62        self.activation = Some(activation);
63        self
64    }
65
66    /// Set use bias
67    pub fn use_bias(mut self, use_bias: bool) -> Self {
68        self.use_bias = use_bias;
69        self
70    }
71
72    /// Set layer name
73    pub fn name(mut self, name: &str) -> Self {
74        self.layer_name = Some(name.to_string());
75        self
76    }
77}
78
79impl KerasLayer for Conv2D {
80    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
81        if !self.built {
82            return Err(MLError::ModelNotTrained(
83                "Layer not built. Call build() first.".to_string(),
84            ));
85        }
86
87        let kernel = self
88            .kernel
89            .as_ref()
90            .ok_or_else(|| MLError::ModelNotTrained("Conv2D kernel not initialized".to_string()))?;
91
92        let shape = input.shape();
93        let (batch, height, width, _in_channels) = (shape[0], shape[1], shape[2], shape[3]);
94
95        let (pad_h, pad_w) = if self.padding == "same" {
96            (self.kernel_size.0 / 2, self.kernel_size.1 / 2)
97        } else {
98            (0, 0)
99        };
100
101        let out_h = (height + 2 * pad_h - self.kernel_size.0) / self.strides.0 + 1;
102        let out_w = (width + 2 * pad_w - self.kernel_size.1) / self.strides.1 + 1;
103
104        let mut output = ArrayD::zeros(IxDyn(&[batch, out_h, out_w, self.filters]));
105
106        for b in 0..batch {
107            for oh in 0..out_h {
108                for ow in 0..out_w {
109                    for f in 0..self.filters {
110                        let mut sum = if self.use_bias {
111                            self.bias.as_ref().map_or(0.0, |bias| bias[[f]])
112                        } else {
113                            0.0
114                        };
115
116                        for kh in 0..self.kernel_size.0 {
117                            for kw in 0..self.kernel_size.1 {
118                                let ih = oh * self.strides.0 + kh;
119                                let iw = ow * self.strides.1 + kw;
120                                if ih < height && iw < width {
121                                    for ic in 0..shape[3] {
122                                        sum += input[[b, ih, iw, ic]] * kernel[[kh, kw, ic, f]];
123                                    }
124                                }
125                            }
126                        }
127                        output[[b, oh, ow, f]] = sum;
128                    }
129                }
130            }
131        }
132
133        if let Some(ref activation) = self.activation {
134            output = output.mapv(|x| match activation {
135                ActivationFunction::ReLU => x.max(0.0),
136                ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
137                ActivationFunction::Tanh => x.tanh(),
138                ActivationFunction::Softmax => x,
139                ActivationFunction::LeakyReLU(alpha) => {
140                    if x > 0.0 {
141                        x
142                    } else {
143                        alpha * x
144                    }
145                }
146                ActivationFunction::ELU(alpha) => {
147                    if x > 0.0 {
148                        x
149                    } else {
150                        alpha * (x.exp() - 1.0)
151                    }
152                }
153                ActivationFunction::Linear => x,
154            });
155        }
156
157        Ok(output)
158    }
159
160    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
161        let in_channels = *input_shape
162            .last()
163            .ok_or_else(|| MLError::InvalidConfiguration("Invalid input shape".to_string()))?;
164
165        let scale = (2.0 / ((self.kernel_size.0 * self.kernel_size.1 * in_channels) as f64)).sqrt();
166        let kernel = ArrayD::from_shape_fn(
167            IxDyn(&[
168                self.kernel_size.0,
169                self.kernel_size.1,
170                in_channels,
171                self.filters,
172            ]),
173            |_| fastrand::f64() * 2.0 * scale - scale,
174        );
175
176        self.kernel = Some(kernel);
177
178        if self.use_bias {
179            self.bias = Some(ArrayD::zeros(IxDyn(&[self.filters])));
180        }
181
182        self.built = true;
183        Ok(())
184    }
185
186    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
187        let (height, width) = (input_shape[1], input_shape[2]);
188        let (pad_h, pad_w) = if self.padding == "same" {
189            (self.kernel_size.0 / 2, self.kernel_size.1 / 2)
190        } else {
191            (0, 0)
192        };
193        let out_h = (height + 2 * pad_h - self.kernel_size.0) / self.strides.0 + 1;
194        let out_w = (width + 2 * pad_w - self.kernel_size.1) / self.strides.1 + 1;
195        vec![input_shape[0], out_h, out_w, self.filters]
196    }
197
198    fn count_params(&self) -> usize {
199        let kernel_params = self.kernel.as_ref().map_or(0, |k| k.len());
200        let bias_params = self.bias.as_ref().map_or(0, |b| b.len());
201        kernel_params + bias_params
202    }
203
204    fn get_weights(&self) -> Vec<ArrayD<f64>> {
205        let mut weights = vec![];
206        if let Some(ref k) = self.kernel {
207            weights.push(k.clone());
208        }
209        if let Some(ref b) = self.bias {
210            weights.push(b.clone());
211        }
212        weights
213    }
214
215    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
216        if !weights.is_empty() {
217            self.kernel = Some(weights[0].clone());
218        }
219        if weights.len() > 1 {
220            self.bias = Some(weights[1].clone());
221        }
222        Ok(())
223    }
224
225    fn built(&self) -> bool {
226        self.built
227    }
228
229    fn name(&self) -> &str {
230        self.layer_name.as_deref().unwrap_or("conv2d")
231    }
232}
233
/// MaxPooling2D layer
///
/// Downsamples NHWC input `[batch, height, width, channels]` by taking the
/// maximum over each pooling window, independently per channel.
pub struct MaxPooling2D {
    /// Window size as (height, width)
    pool_size: (usize, usize),
    /// Window step along (height, width); defaults to pool_size (non-overlapping)
    strides: (usize, usize),
    /// Padding mode: "valid" or "same"
    padding: String,
    /// True once build() has been called
    built: bool,
    /// Optional user-supplied layer name
    layer_name: Option<String>,
}
247
248impl MaxPooling2D {
249    /// Create new MaxPooling2D layer
250    pub fn new(pool_size: (usize, usize)) -> Self {
251        Self {
252            pool_size,
253            strides: pool_size,
254            padding: "valid".to_string(),
255            built: false,
256            layer_name: None,
257        }
258    }
259
260    /// Set strides
261    pub fn strides(mut self, strides: (usize, usize)) -> Self {
262        self.strides = strides;
263        self
264    }
265
266    /// Set padding
267    pub fn padding(mut self, padding: &str) -> Self {
268        self.padding = padding.to_string();
269        self
270    }
271
272    /// Set layer name
273    pub fn name(mut self, name: &str) -> Self {
274        self.layer_name = Some(name.to_string());
275        self
276    }
277}
278
279impl KerasLayer for MaxPooling2D {
280    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
281        let shape = input.shape();
282        let (batch, height, width, channels) = (shape[0], shape[1], shape[2], shape[3]);
283
284        let out_h = (height - self.pool_size.0) / self.strides.0 + 1;
285        let out_w = (width - self.pool_size.1) / self.strides.1 + 1;
286
287        let mut output = ArrayD::zeros(IxDyn(&[batch, out_h, out_w, channels]));
288
289        for b in 0..batch {
290            for oh in 0..out_h {
291                for ow in 0..out_w {
292                    for c in 0..channels {
293                        let mut max_val = f64::NEG_INFINITY;
294                        for ph in 0..self.pool_size.0 {
295                            for pw in 0..self.pool_size.1 {
296                                let ih = oh * self.strides.0 + ph;
297                                let iw = ow * self.strides.1 + pw;
298                                if ih < height && iw < width {
299                                    max_val = max_val.max(input[[b, ih, iw, c]]);
300                                }
301                            }
302                        }
303                        output[[b, oh, ow, c]] = max_val;
304                    }
305                }
306            }
307        }
308
309        Ok(output)
310    }
311
312    fn build(&mut self, _input_shape: &[usize]) -> Result<()> {
313        self.built = true;
314        Ok(())
315    }
316
317    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
318        let out_h = (input_shape[1] - self.pool_size.0) / self.strides.0 + 1;
319        let out_w = (input_shape[2] - self.pool_size.1) / self.strides.1 + 1;
320        vec![input_shape[0], out_h, out_w, input_shape[3]]
321    }
322
323    fn count_params(&self) -> usize {
324        0
325    }
326
327    fn get_weights(&self) -> Vec<ArrayD<f64>> {
328        vec![]
329    }
330
331    fn set_weights(&mut self, _weights: Vec<ArrayD<f64>>) -> Result<()> {
332        Ok(())
333    }
334
335    fn built(&self) -> bool {
336        self.built
337    }
338
339    fn name(&self) -> &str {
340        self.layer_name.as_deref().unwrap_or("max_pooling2d")
341    }
342}
343
/// GlobalAveragePooling2D layer
///
/// Averages NHWC input `[batch, height, width, channels]` over the full
/// spatial extent, producing `[batch, channels]`. Stateless apart from the
/// built flag and optional name.
pub struct GlobalAveragePooling2D {
    /// True once build() has been called
    built: bool,
    /// Optional user-supplied layer name
    layer_name: Option<String>,
}
351
352impl GlobalAveragePooling2D {
353    /// Create new GlobalAveragePooling2D
354    pub fn new() -> Self {
355        Self {
356            built: false,
357            layer_name: None,
358        }
359    }
360
361    /// Set layer name
362    pub fn name(mut self, name: &str) -> Self {
363        self.layer_name = Some(name.to_string());
364        self
365    }
366}
367
368impl Default for GlobalAveragePooling2D {
369    fn default() -> Self {
370        Self::new()
371    }
372}
373
374impl KerasLayer for GlobalAveragePooling2D {
375    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
376        let shape = input.shape();
377        let (batch, height, width, channels) = (shape[0], shape[1], shape[2], shape[3]);
378
379        let mut output = ArrayD::zeros(IxDyn(&[batch, channels]));
380        let count = (height * width) as f64;
381
382        for b in 0..batch {
383            for c in 0..channels {
384                let mut sum = 0.0;
385                for h in 0..height {
386                    for w in 0..width {
387                        sum += input[[b, h, w, c]];
388                    }
389                }
390                output[[b, c]] = sum / count;
391            }
392        }
393
394        Ok(output)
395    }
396
397    fn build(&mut self, _input_shape: &[usize]) -> Result<()> {
398        self.built = true;
399        Ok(())
400    }
401
402    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
403        vec![input_shape[0], input_shape[3]]
404    }
405
406    fn count_params(&self) -> usize {
407        0
408    }
409
410    fn get_weights(&self) -> Vec<ArrayD<f64>> {
411        vec![]
412    }
413
414    fn set_weights(&mut self, _weights: Vec<ArrayD<f64>>) -> Result<()> {
415        Ok(())
416    }
417
418    fn built(&self) -> bool {
419        self.built
420    }
421
422    fn name(&self) -> &str {
423        self.layer_name
424            .as_deref()
425            .unwrap_or("global_average_pooling2d")
426    }
427}