quantrs2_ml/keras_api/
layers.rs

//! Basic layers for Keras-like API

use super::{ActivationFunction, InitializerType, KerasLayer};
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{ArrayD, Axis, IxDyn};

/// Dense (fully connected) layer
pub struct Dense {
    /// Number of units
    units: usize,
    /// Activation function
    activation: Option<ActivationFunction>,
    /// Use bias
    use_bias: bool,
    /// Kernel initializer
    kernel_initializer: InitializerType,
    /// Bias initializer
    bias_initializer: InitializerType,
    /// Layer name
    name: String,
    /// Built flag
    built: bool,
    /// Input shape
    input_shape: Option<Vec<usize>>,
    /// Weights (kernel and bias)
    weights: Vec<ArrayD<f64>>,
}

impl Dense {
    /// Create new dense layer
    pub fn new(units: usize) -> Self {
        Self {
            units,
            activation: None,
            use_bias: true,
            kernel_initializer: InitializerType::GlorotUniform,
            bias_initializer: InitializerType::Zeros,
            name: format!("dense_{}", fastrand::u32(..)),
            built: false,
            input_shape: None,
            weights: Vec::new(),
        }
    }

    /// Set activation function
    pub fn activation(mut self, activation: ActivationFunction) -> Self {
        self.activation = Some(activation);
        self
    }

    /// Set use bias
    pub fn use_bias(mut self, use_bias: bool) -> Self {
        self.use_bias = use_bias;
        self
    }

    /// Set layer name
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Set kernel initializer
    pub fn kernel_initializer(mut self, initializer: InitializerType) -> Self {
        self.kernel_initializer = initializer;
        self
    }

    /// Initialize weights
    fn initialize_weights(
        &self,
        shape: &[usize],
        initializer: &InitializerType,
    ) -> Result<ArrayD<f64>> {
        match initializer {
            InitializerType::Zeros => Ok(ArrayD::zeros(shape)),
            InitializerType::Ones => Ok(ArrayD::ones(shape)),
            InitializerType::GlorotUniform => {
                let fan_in = if shape.len() >= 2 { shape[0] } else { 1 };
                let fan_out = if shape.len() >= 2 { shape[1] } else { shape[0] };
                let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();

                Ok(ArrayD::from_shape_fn(shape, |_| {
                    fastrand::f64() * 2.0 * limit - limit
                }))
            }
            InitializerType::GlorotNormal => {
                let fan_in = if shape.len() >= 2 { shape[0] } else { 1 };
                let fan_out = if shape.len() >= 2 { shape[1] } else { shape[0] };
                let std = (2.0 / (fan_in + fan_out) as f64).sqrt();

                Ok(ArrayD::from_shape_fn(shape, |_| {
                    let u1 = fastrand::f64();
                    let u2 = fastrand::f64();
                    let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                    z * std
                }))
            }
            InitializerType::HeUniform => {
                let fan_in = if shape.len() >= 2 { shape[0] } else { 1 };
                let limit = (6.0 / fan_in as f64).sqrt();

                Ok(ArrayD::from_shape_fn(shape, |_| {
                    fastrand::f64() * 2.0 * limit - limit
                }))
            }
        }
    }

    /// Apply activation function
    fn apply_activation(
        &self,
        inputs: &ArrayD<f64>,
        activation: &ActivationFunction,
    ) -> Result<ArrayD<f64>> {
        Ok(match activation {
            ActivationFunction::Linear => inputs.clone(),
            ActivationFunction::ReLU => inputs.mapv(|x| x.max(0.0)),
            ActivationFunction::Sigmoid => inputs.mapv(|x| 1.0 / (1.0 + (-x).exp())),
            ActivationFunction::Tanh => inputs.mapv(|x| x.tanh()),
            ActivationFunction::Softmax => {
                let mut outputs = inputs.clone();
                for mut row in outputs.axis_iter_mut(Axis(0)) {
                    let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    row.mapv_inplace(|x| (x - max_val).exp());
                    let sum = row.sum();
                    row /= sum;
                }
                outputs
            }
            ActivationFunction::LeakyReLU(alpha) => {
                inputs.mapv(|x| if x > 0.0 { x } else { alpha * x })
            }
            ActivationFunction::ELU(alpha) => {
                inputs.mapv(|x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
            }
        })
    }
}

impl KerasLayer for Dense {
    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        if input_shape.is_empty() {
            return Err(MLError::InvalidConfiguration(
                "Dense layer requires input shape".to_string(),
            ));
        }

        let input_dim = input_shape[input_shape.len() - 1];
        self.input_shape = Some(input_shape.to_vec());

        let kernel = self.initialize_weights(&[input_dim, self.units], &self.kernel_initializer)?;
        self.weights.push(kernel);

        if self.use_bias {
            let bias = self.initialize_weights(&[self.units], &self.bias_initializer)?;
            self.weights.push(bias);
        }

        self.built = true;
        Ok(())
    }

    fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::InvalidConfiguration(
                "Layer must be built before calling".to_string(),
            ));
        }

        let kernel = &self.weights[0];
        let mut outputs = match (inputs.ndim(), kernel.ndim()) {
            (2, 2) => {
                let inputs_2d = inputs
                    .clone()
                    .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                    .map_err(|_| MLError::InvalidConfiguration("Input must be 2D".to_string()))?;
                let kernel_2d = kernel
                    .clone()
                    .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                    .map_err(|_| MLError::InvalidConfiguration("Kernel must be 2D".to_string()))?;
                inputs_2d.dot(&kernel_2d).into_dyn()
            }
            _ => {
                return Err(MLError::InvalidConfiguration(
                    "Unsupported array dimensions for matrix multiplication".to_string(),
                ));
            }
        };

        if self.use_bias && self.weights.len() > 1 {
            let bias = &self.weights[1];
            outputs = outputs + bias;
        }

        if let Some(ref activation) = self.activation {
            outputs = self.apply_activation(&outputs, activation)?;
        }

        Ok(outputs)
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        let mut output_shape = input_shape.to_vec();
        let last_idx = output_shape.len() - 1;
        output_shape[last_idx] = self.units;
        output_shape
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        self.weights.clone()
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        if weights.len() != self.weights.len() {
            return Err(MLError::InvalidConfiguration(
                "Number of weight arrays doesn't match layer structure".to_string(),
            ));
        }
        self.weights = weights;
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }
}
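
// A minimal usage sketch (not part of the original source): it exercises the
// Dense layer defined above on a batch of 2 samples with 4 features and checks
// the output shape. The module name and test values are illustrative
// assumptions, not an established test suite for this crate.
#[cfg(test)]
mod dense_usage_sketch {
    use super::*;

    #[test]
    fn dense_forward_shape() {
        let mut layer = Dense::new(8).activation(ActivationFunction::ReLU);
        layer
            .build(&[4])
            .expect("build should succeed for a non-empty input shape");

        // Batch of 2 samples, 4 features each.
        let inputs = ArrayD::zeros(IxDyn(&[2, 4]));
        let outputs = layer.call(&inputs).expect("forward pass should succeed");

        assert_eq!(outputs.shape(), &[2, 8]);
        assert_eq!(layer.compute_output_shape(&[2, 4]), vec![2, 8]);
    }
}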

/// Activation layer
pub struct Activation {
    /// Activation function
    function: ActivationFunction,
    /// Layer name
    name: String,
    /// Built flag
    built: bool,
}

impl Activation {
    /// Create new activation layer
    pub fn new(function: ActivationFunction) -> Self {
        Self {
            function,
            name: format!("activation_{}", fastrand::u32(..)),
            built: false,
        }
    }

    /// Set layer name
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }
}

impl KerasLayer for Activation {
    fn build(&mut self, _input_shape: &[usize]) -> Result<()> {
        self.built = true;
        Ok(())
    }

    fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        Ok(match &self.function {
            ActivationFunction::Linear => inputs.clone(),
            ActivationFunction::ReLU => inputs.mapv(|x| x.max(0.0)),
            ActivationFunction::Sigmoid => inputs.mapv(|x| 1.0 / (1.0 + (-x).exp())),
            ActivationFunction::Tanh => inputs.mapv(|x| x.tanh()),
            ActivationFunction::Softmax => {
                let mut outputs = inputs.clone();
                for mut row in outputs.axis_iter_mut(Axis(0)) {
                    let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    row.mapv_inplace(|x| (x - max_val).exp());
                    let sum = row.sum();
                    row /= sum;
                }
                outputs
            }
            ActivationFunction::LeakyReLU(alpha) => {
                inputs.mapv(|x| if x > 0.0 { x } else { alpha * x })
            }
            ActivationFunction::ELU(alpha) => {
                inputs.mapv(|x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
            }
        })
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        input_shape.to_vec()
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        Vec::new()
    }

    fn set_weights(&mut self, _weights: Vec<ArrayD<f64>>) -> Result<()> {
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }
}
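
// Brief sketch (illustrative, not from the original source): the stabilised
// softmax above shifts each row by its maximum before exponentiating, so row
// sums stay at 1 even for large inputs. Module name and values are assumptions.
#[cfg(test)]
mod activation_softmax_sketch {
    use super::*;

    #[test]
    fn softmax_rows_sum_to_one() {
        let mut layer = Activation::new(ActivationFunction::Softmax);
        layer.build(&[3]).expect("build should succeed");

        // Large values would overflow a naive softmax; the max-shift prevents it.
        let inputs =
            ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 2.0, 3.0, 1000.0, 1001.0, 1002.0])
                .expect("shape matches data length");
        let outputs = layer.call(&inputs).expect("forward pass should succeed");

        for row in outputs.axis_iter(Axis(0)) {
            assert!((row.sum() - 1.0).abs() < 1e-12);
        }
    }
}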

/// Dropout layer for regularization
pub struct Dropout {
    /// Dropout rate (0 to 1)
    rate: f64,
    /// Layer name
    name: String,
    /// Built flag
    built: bool,
    /// Training mode
    training: bool,
}

impl Dropout {
    /// Create new dropout layer
    pub fn new(rate: f64) -> Self {
        Self {
            rate: rate.clamp(0.0, 1.0),
            name: format!("dropout_{}", fastrand::u32(..)),
            built: false,
            training: true,
        }
    }

    /// Set layer name
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Set training mode
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }
}

impl KerasLayer for Dropout {
    fn build(&mut self, _input_shape: &[usize]) -> Result<()> {
        self.built = true;
        Ok(())
    }

    fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.training || self.rate == 0.0 {
            return Ok(inputs.clone());
        }

        let scale = 1.0 / (1.0 - self.rate);
        let output = inputs.mapv(|x| {
            if fastrand::f64() < self.rate {
                0.0
            } else {
                x * scale
            }
        });

        Ok(output)
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        input_shape.to_vec()
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        Vec::new()
    }

    fn set_weights(&mut self, _weights: Vec<ArrayD<f64>>) -> Result<()> {
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }
}
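
// A minimal sketch (illustrative, not from the original source) of the
// inverted-dropout behaviour above: in inference mode the layer is an identity
// map, while in training mode every kept value is rescaled by 1 / (1 - rate)
// so the expected activation is preserved. Module name and values are assumptions.
#[cfg(test)]
mod dropout_usage_sketch {
    use super::*;

    #[test]
    fn dropout_is_identity_in_inference() {
        let mut layer = Dropout::new(0.5);
        layer.build(&[4]).expect("build should succeed");
        layer.set_training(false);

        let inputs = ArrayD::ones(IxDyn(&[2, 4]));
        let outputs = layer.call(&inputs).expect("forward pass should succeed");
        assert_eq!(outputs, inputs);
    }

    #[test]
    fn dropout_zeroes_or_rescales_in_training() {
        let mut layer = Dropout::new(0.5);
        layer.build(&[4]).expect("build should succeed");

        let inputs = ArrayD::ones(IxDyn(&[1, 100]));
        let outputs = layer.call(&inputs).expect("forward pass should succeed");
        // Every element is either dropped (0.0) or scaled by 1 / (1 - 0.5) = 2.0.
        assert!(outputs.iter().all(|&x| x == 0.0 || (x - 2.0).abs() < 1e-12));
    }
}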

/// Batch normalization layer
pub struct BatchNormalization {
    /// Momentum for moving average
    momentum: f64,
    /// Epsilon for numerical stability
    epsilon: f64,
    /// Use center (beta)
    center: bool,
    /// Use scale (gamma)
    scale: bool,
    /// Layer name
    name: String,
    /// Built flag
    built: bool,
    /// Weights: [gamma, beta, moving_mean, moving_var]
    weights: Vec<ArrayD<f64>>,
    /// Training mode
    training: bool,
}

impl BatchNormalization {
    /// Create new batch normalization layer
    pub fn new() -> Self {
        Self {
            momentum: 0.99,
            epsilon: 1e-3,
            center: true,
            scale: true,
            name: format!("batch_norm_{}", fastrand::u32(..)),
            built: false,
            weights: Vec::new(),
            training: true,
        }
    }

    /// Set momentum
    pub fn momentum(mut self, momentum: f64) -> Self {
        self.momentum = momentum;
        self
    }

    /// Set epsilon
    pub fn epsilon(mut self, epsilon: f64) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Set layer name
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Set training mode
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }
}

impl Default for BatchNormalization {
    fn default() -> Self {
        Self::new()
    }
}

impl KerasLayer for BatchNormalization {
    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        if input_shape.is_empty() {
            return Err(MLError::InvalidConfiguration(
                "BatchNormalization layer requires input shape".to_string(),
            ));
        }

        let features = input_shape[input_shape.len() - 1];

        let gamma = ArrayD::ones(IxDyn(&[features]));
        self.weights.push(gamma);

        let beta = ArrayD::zeros(IxDyn(&[features]));
        self.weights.push(beta);

        let moving_mean = ArrayD::zeros(IxDyn(&[features]));
        self.weights.push(moving_mean);

        let moving_var = ArrayD::ones(IxDyn(&[features]));
        self.weights.push(moving_var);

        self.built = true;
        Ok(())
    }

    fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::InvalidConfiguration("Layer not built".to_string()));
        }

        let gamma = &self.weights[0];
        let beta = &self.weights[1];
        let moving_mean = &self.weights[2];
        let moving_var = &self.weights[3];

        let shape = inputs.shape();
        let features = shape[shape.len() - 1];

        // Note: this forward pass normalizes with the stored moving statistics
        // (inference-style) regardless of the `training` flag; batch statistics
        // and moving-average updates are not computed here.
        let mut output = inputs.clone();

        for (i, val) in output.iter_mut().enumerate() {
            let f = i % features;
            let mean = moving_mean[[f]];
            let var = moving_var[[f]];
            let std = (var + self.epsilon).sqrt();

            *val = (*val - mean) / std;

            if self.scale {
                *val *= gamma[[f]];
            }
            if self.center {
                *val += beta[[f]];
            }
        }

        Ok(output)
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        input_shape.to_vec()
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        self.weights.clone()
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        if weights.len() != 4 {
            return Err(MLError::InvalidConfiguration(
                "BatchNormalization requires 4 weight arrays".to_string(),
            ));
        }
        self.weights = weights;
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }
}
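
// Illustrative sketch (not from the original source): with freshly built
// parameters (gamma = 1, beta = 0, moving_mean = 0, moving_var = 1) the layer
// reduces to y = x / sqrt(1 + epsilon), i.e. an almost-identity map. Module
// name and values are assumptions.
#[cfg(test)]
mod batch_norm_usage_sketch {
    use super::*;

    #[test]
    fn default_parameters_are_near_identity() {
        let mut layer = BatchNormalization::new();
        layer.build(&[3]).expect("build should succeed");

        let inputs =
            ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .expect("shape matches data length");
        let outputs = layer.call(&inputs).expect("forward pass should succeed");

        // epsilon defaults to 1e-3, so each value is divided by sqrt(1.001).
        let scale = 1.0 / (1.0f64 + 1e-3).sqrt();
        for (y, x) in outputs.iter().zip(inputs.iter()) {
            assert!((y - x * scale).abs() < 1e-12);
        }
    }
}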

/// Flatten layer to reshape inputs
pub struct Flatten {
    /// Layer name
    name: String,
    /// Built flag
    built: bool,
    /// Input shape
    input_shape: Option<Vec<usize>>,
}

impl Flatten {
    /// Create new flatten layer
    pub fn new() -> Self {
        Self {
            name: format!("flatten_{}", fastrand::u32(..)),
            built: false,
            input_shape: None,
        }
    }

    /// Set layer name
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }
}

impl Default for Flatten {
    fn default() -> Self {
        Self::new()
    }
}

impl KerasLayer for Flatten {
    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        self.input_shape = Some(input_shape.to_vec());
        self.built = true;
        Ok(())
    }

    fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        let shape = inputs.shape();
        if shape.is_empty() {
            return Ok(inputs.clone());
        }

        let batch_size = shape[0];
        let flat_size: usize = shape[1..].iter().product();

        let output = inputs
            .clone()
            .into_shape(IxDyn(&[batch_size, flat_size]))
            .map_err(|e| MLError::InvalidConfiguration(format!("Reshape failed: {}", e)))?;

        Ok(output)
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        if input_shape.is_empty() {
            return vec![];
        }
        let flat_size: usize = input_shape[1..].iter().product();
        vec![input_shape[0], flat_size]
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        Vec::new()
    }

    fn set_weights(&mut self, _weights: Vec<ArrayD<f64>>) -> Result<()> {
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }
}
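
// End-to-end sketch (illustrative assumption, not from the original source):
// flattening a batch of 2x2 "images" and feeding the result through a Dense
// layer, mirroring a typical Keras-style Flatten -> Dense head. Module name
// and values are assumptions.
#[cfg(test)]
mod flatten_dense_pipeline_sketch {
    use super::*;

    #[test]
    fn flatten_then_dense() {
        let mut flatten = Flatten::new();
        flatten.build(&[2, 2, 2]).expect("build should succeed");

        let mut dense = Dense::new(3).activation(ActivationFunction::Sigmoid);
        dense.build(&[2, 4]).expect("build should succeed");

        // Batch of 2 samples, each a 2x2 grid.
        let inputs = ArrayD::zeros(IxDyn(&[2, 2, 2]));
        let flat = flatten.call(&inputs).expect("flatten should succeed");
        assert_eq!(flat.shape(), &[2, 4]);

        let outputs = dense.call(&flat).expect("dense forward pass should succeed");
        assert_eq!(outputs.shape(), &[2, 3]);
        // Zero inputs give zero pre-activations, and sigmoid(0) = 0.5 everywhere.
        assert!(outputs.iter().all(|&x| (x - 0.5).abs() < 1e-12));
    }
}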