KerasLayer

Trait KerasLayer 

Source
pub trait KerasLayer: Send + Sync {
    // Required methods
    fn build(&mut self, input_shape: &[usize]) -> Result<()>;
    fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>>;
    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize>;
    fn name(&self) -> &str;
    fn get_weights(&self) -> Vec<ArrayD<f64>>;
    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()>;
    fn built(&self) -> bool;

    // Provided method
    fn count_params(&self) -> usize { ... }
}
Expand description

Keras-style layer trait

Required Methods

Source

fn build(&mut self, input_shape: &[usize]) -> Result<()>

Build the layer (called during model compilation)

Source

fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>>

Forward pass through the layer

Source

fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize>

Compute output shape given input shape

Source

fn name(&self) -> &str

Get layer name

Source

fn get_weights(&self) -> Vec<ArrayD<f64>>

Get trainable parameters

Source

fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()>

Set trainable parameters

Source

fn built(&self) -> bool

Check whether the layer has been built

Provided Methods

Source

fn count_params(&self) -> usize

Get number of parameters

Implementors