use std::ptr;
use flodl_sys::{self as ffi, FlodlTensor};
use super::{Tensor, check_err, Result};
/// Owning handle to a backend-allocated RNN parameter bundle
/// (see `RnnParams::new`); the native allocation is released in `Drop`.
pub struct RnnParams {
// Opaque pointer returned by `flodl_rnn_params_create`; owned by this
// struct and freed exactly once in `Drop`.
handle: *mut std::os::raw::c_void,
}
impl RnnParams {
/// Bundles the given parameter tensors into a native RNN parameter
/// object understood by the backend.
///
/// `mode` is a backend-defined RNN-type code and `flatten` presumably
/// requests flattened/contiguous weight storage — confirm against the
/// `flodl_sys` headers. Returns an owning wrapper whose native handle
/// is released on drop, or an error if the backend call fails.
pub fn new(
params: &[Tensor], mode: i64, num_layers: i64,
batch_first: bool, flatten: bool,
) -> Result<Self> {
// Borrow every tensor's raw handle into one contiguous array for C.
let raw_handles: Vec<FlodlTensor> = params.iter().map(|t| t.handle).collect();
// Out-parameter the backend fills with the created bundle.
let mut created: *mut std::os::raw::c_void = ptr::null_mut();
let status = unsafe {
ffi::flodl_rnn_params_create(
raw_handles.as_ptr(),
raw_handles.len() as i64,
mode,
num_layers,
batch_first,
flatten,
&mut created,
)
};
check_err(status)?;
Ok(RnnParams { handle: created })
}
}
impl Drop for RnnParams {
fn drop(&mut self) {
// Release the native parameter bundle. `handle` is assumed non-null
// here because construction only succeeds after `check_err` passes;
// if the C free is not null-safe, confirm that the backend never
// reports success without populating the out-pointer.
unsafe { ffi::flodl_rnn_params_free(self.handle) }
}
}
impl Tensor {
// NOTE: throughout this impl, integer "code" parameters (`mode`,
// `reduction`, `padding_mode`, ...) are passed straight to the C API;
// their meanings are backend-defined and not visible from this file —
// confirm against the flodl_sys headers. A null tensor handle (from
// `Option::map_or(ptr::null_mut(), ...)`) signals "absent" to the C side.
// Fixed-size array arguments are copied into `mut` locals solely because
// the FFI signatures take `*mut i64`; the C side is presumed not to
// mutate them — TODO confirm.

/// Layer normalization over the trailing `normalized_size` elements.
/// Returns `(output, mean, rstd)`; the latter two are presumably the
/// saved per-row statistics — confirm against the backend docs.
pub fn native_layer_norm(
&self, weight: &Tensor, bias: &Tensor, normalized_size: i64, eps: f64,
) -> Result<(Tensor, Tensor, Tensor)> {
let mut out: FlodlTensor = ptr::null_mut();
let mut mean: FlodlTensor = ptr::null_mut();
let mut rstd: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_native_layer_norm(
self.handle, weight.handle, bias.handle,
normalized_size, eps,
&mut out, &mut mean, &mut rstd,
)
};
check_err(err)?;
Ok((Tensor::from_raw(out), Tensor::from_raw(mean), Tensor::from_raw(rstd)))
}
/// 2-D convolution of `self` with `weight`; `bias = None` maps to a
/// null handle (no bias).
#[allow(clippy::too_many_arguments)]
pub fn conv2d(
&self, weight: &Tensor, bias: Option<&Tensor>,
stride: [i64; 2], padding: [i64; 2], dilation: [i64; 2], groups: i64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
// Local mut copies only to satisfy the `*mut i64` FFI signature.
let mut stride = stride;
let mut padding = padding;
let mut dilation = dilation;
let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
let err = unsafe {
ffi::flodl_conv2d(
self.handle, weight.handle, bias_handle,
stride.as_mut_ptr(), padding.as_mut_ptr(), dilation.as_mut_ptr(),
groups, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// 2-D transposed convolution (a.k.a. deconvolution) of `self`.
#[allow(clippy::too_many_arguments)]
pub fn conv_transpose2d(
&self, weight: &Tensor, bias: Option<&Tensor>,
stride: [i64; 2], padding: [i64; 2], output_padding: [i64; 2],
dilation: [i64; 2], groups: i64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let mut stride = stride;
let mut padding = padding;
let mut output_padding = output_padding;
let mut dilation = dilation;
let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
let err = unsafe {
ffi::flodl_conv_transpose2d(
self.handle, weight.handle, bias_handle,
stride.as_mut_ptr(), padding.as_mut_ptr(),
output_padding.as_mut_ptr(), dilation.as_mut_ptr(),
groups, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// 1-D convolution; scalar stride/padding/dilation are passed by value.
#[allow(clippy::too_many_arguments)]
pub fn conv1d(
&self, weight: &Tensor, bias: Option<&Tensor>,
stride: i64, padding: i64, dilation: i64, groups: i64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
let err = unsafe {
ffi::flodl_conv1d(
self.handle, weight.handle, bias_handle,
stride, padding, dilation,
groups, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// 1-D transposed convolution.
#[allow(clippy::too_many_arguments)]
pub fn conv_transpose1d(
&self, weight: &Tensor, bias: Option<&Tensor>,
stride: i64, padding: i64, output_padding: i64,
dilation: i64, groups: i64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
let err = unsafe {
ffi::flodl_conv_transpose1d(
self.handle, weight.handle, bias_handle,
stride, padding, output_padding, dilation,
groups, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Group normalization with optional affine `weight`/`bias`
/// (null handles when absent).
pub fn group_norm(
&self, num_groups: i64,
weight: Option<&Tensor>, bias: Option<&Tensor>,
eps: f64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let w = weight.map_or(ptr::null_mut(), |t| t.handle);
let b = bias.map_or(ptr::null_mut(), |t| t.handle);
let err = unsafe {
ffi::flodl_group_norm(
self.handle, num_groups, w, b, eps, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Affine transform `self @ weight^T + bias` (bias optional) —
/// layout presumed to match the usual linear-layer convention; confirm
/// against the backend.
pub fn linear(&self, weight: &Tensor, bias: Option<&Tensor>) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
let err = unsafe {
ffi::flodl_linear(self.handle, weight.handle, bias_handle, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Single GRU cell step: input `self`, hidden state `hx`, and the four
/// weight/bias tensors. Note: biases are NOT optional here, unlike the
/// conv/linear wrappers.
#[allow(clippy::too_many_arguments)]
pub fn gru_cell(
&self, hx: &Tensor,
w_ih: &Tensor, w_hh: &Tensor,
b_ih: &Tensor, b_hh: &Tensor,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_gru_cell(
self.handle, hx.handle,
w_ih.handle, w_hh.handle,
b_ih.handle, b_hh.handle,
&mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Single LSTM cell step; returns the new `(hidden, cell)` states.
#[allow(clippy::too_many_arguments)]
pub fn lstm_cell(
&self, hx: &Tensor, cx: &Tensor,
w_ih: &Tensor, w_hh: &Tensor,
b_ih: &Tensor, b_hh: &Tensor,
) -> Result<(Tensor, Tensor)> {
let mut h_out: FlodlTensor = ptr::null_mut();
let mut c_out: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_lstm_cell(
self.handle, hx.handle, cx.handle,
w_ih.handle, w_hh.handle,
b_ih.handle, b_hh.handle,
&mut h_out, &mut c_out,
)
};
check_err(err)?;
Ok((Tensor::from_raw(h_out), Tensor::from_raw(c_out)))
}
/// Full-sequence LSTM over `self` with flat parameter tensors.
/// Returns `(output, h_n, c_n)`. For repeated calls, prefer
/// `lstm_seq_cached` with a pre-built `RnnParams`.
pub fn lstm_seq(
&self, h_0: &Tensor, c_0: &Tensor,
params: &[Tensor], num_layers: i64, batch_first: bool, flatten: bool,
) -> Result<(Tensor, Tensor, Tensor)> {
// Raw handles collected into a contiguous array for the C API.
let handles: Vec<FlodlTensor> = params.iter().map(|t| t.handle).collect();
let mut output: FlodlTensor = ptr::null_mut();
let mut h_n: FlodlTensor = ptr::null_mut();
let mut c_n: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_lstm(
self.handle, h_0.handle, c_0.handle,
handles.as_ptr(), handles.len() as i64,
num_layers, batch_first, flatten,
&mut output, &mut h_n, &mut c_n,
)
};
check_err(err)?;
Ok((Tensor::from_raw(output), Tensor::from_raw(h_n), Tensor::from_raw(c_n)))
}
/// Full-sequence GRU over `self`; returns `(output, h_n)`.
pub fn gru_seq(
&self, h_0: &Tensor,
params: &[Tensor], num_layers: i64, batch_first: bool, flatten: bool,
) -> Result<(Tensor, Tensor)> {
let handles: Vec<FlodlTensor> = params.iter().map(|t| t.handle).collect();
let mut output: FlodlTensor = ptr::null_mut();
let mut h_n: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_gru(
self.handle, h_0.handle,
handles.as_ptr(), handles.len() as i64,
num_layers, batch_first, flatten,
&mut output, &mut h_n,
)
};
check_err(err)?;
Ok((Tensor::from_raw(output), Tensor::from_raw(h_n)))
}
/// Full-sequence LSTM using a pre-packed `RnnParams` bundle instead of
/// a tensor slice (avoids re-packing parameters every call).
pub fn lstm_seq_cached(
&self, h_0: &Tensor, c_0: &Tensor,
params: &RnnParams, num_layers: i64, batch_first: bool,
) -> Result<(Tensor, Tensor, Tensor)> {
let mut output: FlodlTensor = ptr::null_mut();
let mut h_n: FlodlTensor = ptr::null_mut();
let mut c_n: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_lstm_cached(
self.handle, h_0.handle, c_0.handle,
params.handle, num_layers, batch_first,
&mut output, &mut h_n, &mut c_n,
)
};
check_err(err)?;
Ok((Tensor::from_raw(output), Tensor::from_raw(h_n), Tensor::from_raw(c_n)))
}
/// Full-sequence GRU using a pre-packed `RnnParams` bundle.
pub fn gru_seq_cached(
&self, h_0: &Tensor,
params: &RnnParams, num_layers: i64, batch_first: bool,
) -> Result<(Tensor, Tensor)> {
let mut output: FlodlTensor = ptr::null_mut();
let mut h_n: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_gru_cached(
self.handle, h_0.handle,
params.handle, num_layers, batch_first,
&mut output, &mut h_n,
)
};
check_err(err)?;
Ok((Tensor::from_raw(output), Tensor::from_raw(h_n)))
}
/// 2-D max pooling. `ceil_mode` is converted to `i32` for the C ABI.
pub fn max_pool2d(
&self,
kernel_size: [i64; 2],
stride: [i64; 2],
padding: [i64; 2],
dilation: [i64; 2],
ceil_mode: bool,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let mut ks = kernel_size;
let mut st = stride;
let mut pd = padding;
let mut dl = dilation;
let err = unsafe {
ffi::flodl_max_pool2d(
self.handle,
ks.as_mut_ptr(), st.as_mut_ptr(),
pd.as_mut_ptr(), dl.as_mut_ptr(),
ceil_mode as i32, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// 2-D average pooling; `count_include_pad` controls whether padded
/// zeros enter the divisor (per the usual avg-pool convention — confirm
/// against the backend).
pub fn avg_pool2d(
&self,
kernel_size: [i64; 2],
stride: [i64; 2],
padding: [i64; 2],
ceil_mode: bool,
count_include_pad: bool,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let mut ks = kernel_size;
let mut st = stride;
let mut pd = padding;
let err = unsafe {
ffi::flodl_avg_pool2d(
self.handle,
ks.as_mut_ptr(), st.as_mut_ptr(), pd.as_mut_ptr(),
ceil_mode as i32, count_include_pad as i32,
&mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Adaptive 2-D average pooling to the exact `output_size`.
pub fn adaptive_avg_pool2d(&self, output_size: [i64; 2]) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let mut os = output_size;
let err = unsafe {
ffi::flodl_adaptive_avg_pool2d(self.handle, os.as_mut_ptr(), &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Samples `self` at the locations given by `grid`. `mode` and
/// `padding_mode` are backend-defined codes.
pub fn grid_sample(
&self, grid: &Tensor, mode: i32, padding_mode: i32, align_corners: bool,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_grid_sample(
self.handle, grid.handle, mode, padding_mode,
align_corners as i32, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Mean-squared-error loss; `reduction` is a backend-defined code
/// (presumably none/mean/sum — confirm against flodl_sys).
pub fn mse_loss(&self, target: &Tensor, reduction: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_mse_loss(self.handle, target.handle, reduction, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Cross-entropy loss with `ignore_index` and label smoothing.
#[allow(clippy::too_many_arguments)]
pub fn cross_entropy_loss(
&self, target: &Tensor, reduction: i64,
ignore_index: i64, label_smoothing: f64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_cross_entropy_loss(
self.handle, target.handle,
reduction, ignore_index, label_smoothing,
&mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Binary cross-entropy loss (expects probabilities — see
/// `bce_with_logits_loss` for the logit variant).
pub fn bce_loss(&self, target: &Tensor, reduction: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_bce_loss(
self.handle, target.handle, reduction, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Binary cross-entropy on raw logits.
pub fn bce_with_logits_loss(&self, target: &Tensor, reduction: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_bce_with_logits_loss(
self.handle, target.handle, reduction, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// L1 (mean absolute error) loss.
pub fn l1_loss(&self, target: &Tensor, reduction: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_l1_loss(self.handle, target.handle, reduction, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Smooth-L1 (Huber-style) loss with transition point `beta`.
pub fn smooth_l1_loss(&self, target: &Tensor, reduction: i64, beta: f64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_smooth_l1_loss(
self.handle, target.handle, reduction, beta, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// KL-divergence loss; `log_target` indicates `target` is already in
/// log space.
pub fn kl_div_loss(&self, target: &Tensor, reduction: i64, log_target: bool) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_kl_div_loss(
self.handle, target.handle, reduction, log_target as i32, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Negative log-likelihood loss.
pub fn nll_loss(&self, target: &Tensor, reduction: i64, ignore_index: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_nll_loss(self.handle, target.handle, reduction, ignore_index, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Connectionist temporal classification loss; `blank` is the blank
/// label index. Lengths are passed as tensors, not slices.
pub fn ctc_loss(
&self, targets: &Tensor, input_lengths: &Tensor, target_lengths: &Tensor,
blank: i64, reduction: i64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_ctc_loss(
self.handle, targets.handle,
input_lengths.handle, target_lengths.handle,
blank, reduction, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Batch normalization; all four statistics/affine tensors are optional
/// and map to null handles when absent. In `training` mode the backend
/// presumably updates `running_mean`/`running_var` in place — confirm.
#[allow(clippy::too_many_arguments)]
pub fn batch_norm(
&self, weight: Option<&Tensor>, bias: Option<&Tensor>,
running_mean: Option<&Tensor>, running_var: Option<&Tensor>,
training: bool, momentum: f64, eps: f64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let w = weight.map_or(ptr::null_mut(), |t| t.handle);
let b = bias.map_or(ptr::null_mut(), |t| t.handle);
let rm = running_mean.map_or(ptr::null_mut(), |t| t.handle);
let rv = running_var.map_or(ptr::null_mut(), |t| t.handle);
let err = unsafe {
ffi::flodl_batch_norm(
self.handle, w, b, rm, rv,
training as i32, momentum, eps, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Element dropout with probability `p`; a no-op pass-through is
/// presumably applied when `training` is false — confirm.
pub fn dropout(&self, p: f64, training: bool) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_dropout(self.handle, p, training as i32, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Channel-wise (feature) dropout.
pub fn feature_dropout(&self, p: f64, training: bool) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_feature_dropout(self.handle, p, training as i32, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Embedding-bag lookup-and-reduce. Associated function (no receiver):
/// `weight` is the embedding table, `mode` the backend reduction code.
pub fn embedding_bag(
weight: &Tensor, indices: &Tensor, offsets: &Tensor, mode: i64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_embedding_bag(
weight.handle, indices.handle, offsets.handle,
mode, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Resizes `self` to `output_size` using interpolation `mode`.
/// Note the length is passed as `i32` here, unlike the `i64` counts in
/// the RNN wrappers — this matches the FFI signature as generated.
pub fn interpolate(
&self, output_size: &[i64], mode: i32, align_corners: bool,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
// Copied to a Vec because the FFI takes *mut i64.
let mut os = output_size.to_vec();
let err = unsafe {
ffi::flodl_interpolate(
self.handle, os.as_mut_ptr(), os.len() as i32,
mode, align_corners as i32, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Unfolds sliding local blocks into columns (im2col).
pub fn im2col(
&self, kernel_size: [i64; 2], dilation: [i64; 2],
padding: [i64; 2], stride: [i64; 2],
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let mut ks = kernel_size;
let mut dl = dilation;
let mut pd = padding;
let mut st = stride;
let err = unsafe {
ffi::flodl_im2col(
self.handle, ks.as_mut_ptr(), dl.as_mut_ptr(),
pd.as_mut_ptr(), st.as_mut_ptr(), &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Folds columns back into an image of `output_size` (col2im, the
/// inverse of `im2col`).
pub fn col2im(
&self, output_size: [i64; 2], kernel_size: [i64; 2],
dilation: [i64; 2], padding: [i64; 2], stride: [i64; 2],
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let mut os = output_size;
let mut ks = kernel_size;
let mut dl = dilation;
let mut pd = padding;
let mut st = stride;
let err = unsafe {
ffi::flodl_col2im(
self.handle, os.as_mut_ptr(), ks.as_mut_ptr(),
dl.as_mut_ptr(), pd.as_mut_ptr(), st.as_mut_ptr(),
&mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// 3-D convolution (same contract as `conv2d`, with 3-element params).
#[allow(clippy::too_many_arguments)]
pub fn conv3d(
&self, weight: &Tensor, bias: Option<&Tensor>,
stride: [i64; 3], padding: [i64; 3], dilation: [i64; 3], groups: i64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let mut stride = stride;
let mut padding = padding;
let mut dilation = dilation;
let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
let err = unsafe {
ffi::flodl_conv3d(
self.handle, weight.handle, bias_handle,
stride.as_mut_ptr(), padding.as_mut_ptr(), dilation.as_mut_ptr(),
groups, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// 3-D transposed convolution.
#[allow(clippy::too_many_arguments)]
pub fn conv_transpose3d(
&self, weight: &Tensor, bias: Option<&Tensor>,
stride: [i64; 3], padding: [i64; 3], output_padding: [i64; 3],
dilation: [i64; 3], groups: i64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let mut stride = stride;
let mut padding = padding;
let mut output_padding = output_padding;
let mut dilation = dilation;
let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
let err = unsafe {
ffi::flodl_conv_transpose3d(
self.handle, weight.handle, bias_handle,
stride.as_mut_ptr(), padding.as_mut_ptr(),
output_padding.as_mut_ptr(), dilation.as_mut_ptr(),
groups, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// 1-D max pooling with scalar parameters.
pub fn max_pool1d(
&self, kernel_size: i64, stride: i64, padding: i64, dilation: i64, ceil_mode: bool,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_max_pool1d(
self.handle, kernel_size, stride, padding, dilation,
ceil_mode as i32, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// 1-D average pooling with scalar parameters.
pub fn avg_pool1d(
&self, kernel_size: i64, stride: i64, padding: i64,
ceil_mode: bool, count_include_pad: bool,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_avg_pool1d(
self.handle, kernel_size, stride, padding,
ceil_mode as i32, count_include_pad as i32, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Adaptive 2-D max pooling to the exact `output_size`. Returns only
/// the pooled values (no indices tensor is exposed here).
pub fn adaptive_max_pool2d(&self, output_size: [i64; 2]) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let mut os = output_size;
let err = unsafe {
ffi::flodl_adaptive_max_pool2d(self.handle, os.as_mut_ptr(), &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Instance normalization; mirrors `batch_norm` but with
/// `use_input_stats` selecting per-instance statistics.
#[allow(clippy::too_many_arguments)]
pub fn instance_norm(
&self, weight: Option<&Tensor>, bias: Option<&Tensor>,
running_mean: Option<&Tensor>, running_var: Option<&Tensor>,
use_input_stats: bool, momentum: f64, eps: f64,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let w = weight.map_or(ptr::null_mut(), |t| t.handle);
let b = bias.map_or(ptr::null_mut(), |t| t.handle);
let rm = running_mean.map_or(ptr::null_mut(), |t| t.handle);
let rv = running_var.map_or(ptr::null_mut(), |t| t.handle);
let err = unsafe {
ffi::flodl_instance_norm(
self.handle, w, b, rm, rv,
use_input_stats as i32, momentum, eps, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Rearranges channels into spatial blocks (depth-to-space).
pub fn pixel_shuffle(&self, upscale_factor: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_pixel_shuffle(self.handle, upscale_factor, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Inverse of `pixel_shuffle` (space-to-depth).
pub fn pixel_unshuffle(&self, downscale_factor: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_pixel_unshuffle(self.handle, downscale_factor, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
/// Bilinear transform of two inputs with `weight` and optional `bias`.
/// Associated function (no receiver) — both inputs are explicit.
pub fn bilinear(
input1: &Tensor, input2: &Tensor, weight: &Tensor, bias: Option<&Tensor>,
) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
let err = unsafe {
ffi::flodl_bilinear(
input1.handle, input2.handle, weight.handle, bias_handle,
&mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
}