use crate::tensor::{Result, Tensor};
use super::variable::Variable;
// Differentiable-op surface of `Variable`.
//
// Every method below follows one pattern: unwrap the underlying tensor(s)
// with `data()`, invoke the identically named `Tensor` op, and rewrap the
// resulting tensor(s) in a fresh `Variable` via `Variable::wrap`,
// propagating any tensor-level error through `Result`.
impl Variable {
    // ---- binary arithmetic and matrix product ----------------------------

    pub fn add(&self, other: &Variable) -> Result<Variable> {
        self.data().add(&other.data()).map(Variable::wrap)
    }

    pub fn sub(&self, other: &Variable) -> Result<Variable> {
        self.data().sub(&other.data()).map(Variable::wrap)
    }

    pub fn mul(&self, other: &Variable) -> Result<Variable> {
        self.data().mul(&other.data()).map(Variable::wrap)
    }

    pub fn div(&self, other: &Variable) -> Result<Variable> {
        self.data().div(&other.data()).map(Variable::wrap)
    }

    pub fn matmul(&self, other: &Variable) -> Result<Variable> {
        self.data().matmul(&other.data()).map(Variable::wrap)
    }

    // ---- scalar arithmetic -----------------------------------------------

    pub fn mul_scalar(&self, scalar: f64) -> Result<Variable> {
        self.data().mul_scalar(scalar).map(Variable::wrap)
    }

    pub fn div_scalar(&self, scalar: f64) -> Result<Variable> {
        self.data().div_scalar(scalar).map(Variable::wrap)
    }

    pub fn add_scalar(&self, scalar: f64) -> Result<Variable> {
        self.data().add_scalar(scalar).map(Variable::wrap)
    }

    // ---- unary ops and activations ---------------------------------------

    pub fn neg(&self) -> Result<Variable> {
        self.data().neg().map(Variable::wrap)
    }

    pub fn relu(&self) -> Result<Variable> {
        self.data().relu().map(Variable::wrap)
    }

    pub fn sigmoid(&self) -> Result<Variable> {
        self.data().sigmoid().map(Variable::wrap)
    }

    pub fn tanh(&self) -> Result<Variable> {
        self.data().tanh().map(Variable::wrap)
    }

    pub fn gelu(&self) -> Result<Variable> {
        self.data().gelu().map(Variable::wrap)
    }

    pub fn gelu_tanh(&self) -> Result<Variable> {
        self.data().gelu_tanh().map(Variable::wrap)
    }

    pub fn silu(&self) -> Result<Variable> {
        self.data().silu().map(Variable::wrap)
    }

    pub fn leaky_relu(&self, negative_slope: f64) -> Result<Variable> {
        self.data().leaky_relu(negative_slope).map(Variable::wrap)
    }

    pub fn elu(&self, alpha: f64) -> Result<Variable> {
        self.data().elu(alpha).map(Variable::wrap)
    }

    pub fn softplus(&self, beta: f64, threshold: f64) -> Result<Variable> {
        self.data().softplus(beta, threshold).map(Variable::wrap)
    }

    pub fn mish(&self) -> Result<Variable> {
        self.data().mish().map(Variable::wrap)
    }

    pub fn selu(&self) -> Result<Variable> {
        self.data().selu().map(Variable::wrap)
    }

    pub fn hardswish(&self) -> Result<Variable> {
        self.data().hardswish().map(Variable::wrap)
    }

    pub fn hardsigmoid(&self) -> Result<Variable> {
        self.data().hardsigmoid().map(Variable::wrap)
    }

    /// PReLU with a learnable `weight` variable (unwrapped like `self`).
    pub fn prelu(&self, weight: &Variable) -> Result<Variable> {
        self.data().prelu(&weight.data()).map(Variable::wrap)
    }

    // ---- reductions ------------------------------------------------------

    pub fn sum(&self) -> Result<Variable> {
        self.data().sum().map(Variable::wrap)
    }

    pub fn mean(&self) -> Result<Variable> {
        self.data().mean().map(Variable::wrap)
    }

    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Result<Variable> {
        self.data().sum_dim(dim, keepdim).map(Variable::wrap)
    }

    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Result<Variable> {
        self.data().mean_dim(dim, keepdim).map(Variable::wrap)
    }

    pub fn prod(&self) -> Result<Variable> {
        self.data().prod().map(Variable::wrap)
    }

    pub fn prod_dim(&self, dim: i32, keepdim: bool) -> Result<Variable> {
        self.data().prod_dim(dim, keepdim).map(Variable::wrap)
    }

    pub fn cumsum(&self, dim: i32) -> Result<Variable> {
        self.data().cumsum(dim).map(Variable::wrap)
    }

    pub fn logsumexp(&self, dim: i32, keepdim: bool) -> Result<Variable> {
        self.data().logsumexp(dim, keepdim).map(Variable::wrap)
    }

    pub fn min(&self) -> Result<Variable> {
        self.data().min().map(Variable::wrap)
    }

    pub fn max(&self) -> Result<Variable> {
        self.data().max().map(Variable::wrap)
    }

    pub fn min_dim(&self, dim: i32, keepdim: bool) -> Result<Variable> {
        self.data().min_dim(dim, keepdim).map(Variable::wrap)
    }

    pub fn max_dim(&self, dim: i32, keepdim: bool) -> Result<Variable> {
        self.data().max_dim(dim, keepdim).map(Variable::wrap)
    }

    pub fn var(&self) -> Result<Variable> {
        self.data().var().map(Variable::wrap)
    }

    pub fn std(&self) -> Result<Variable> {
        self.data().std().map(Variable::wrap)
    }

    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Result<Variable> {
        self.data().var_dim(dim, keepdim).map(Variable::wrap)
    }

    pub fn std_dim(&self, dim: i32, keepdim: bool) -> Result<Variable> {
        self.data().std_dim(dim, keepdim).map(Variable::wrap)
    }

    // ---- normalizing maps ------------------------------------------------

    pub fn softmax(&self, dim: i32) -> Result<Variable> {
        self.data().softmax(dim).map(Variable::wrap)
    }

    pub fn log_softmax(&self, dim: i32) -> Result<Variable> {
        self.data().log_softmax(dim).map(Variable::wrap)
    }

    // ---- element-wise math -----------------------------------------------

    pub fn exp(&self) -> Result<Variable> {
        self.data().exp().map(Variable::wrap)
    }

    pub fn log(&self) -> Result<Variable> {
        self.data().log().map(Variable::wrap)
    }

    pub fn sqrt(&self) -> Result<Variable> {
        self.data().sqrt().map(Variable::wrap)
    }

    pub fn abs(&self) -> Result<Variable> {
        self.data().abs().map(Variable::wrap)
    }

    pub fn pow_scalar(&self, exponent: f64) -> Result<Variable> {
        self.data().pow_scalar(exponent).map(Variable::wrap)
    }

    pub fn triu(&self, diagonal: i64) -> Result<Variable> {
        self.data().triu(diagonal).map(Variable::wrap)
    }

    pub fn tril(&self, diagonal: i64) -> Result<Variable> {
        self.data().tril(diagonal).map(Variable::wrap)
    }

    pub fn sin(&self) -> Result<Variable> {
        self.data().sin().map(Variable::wrap)
    }

    pub fn cos(&self) -> Result<Variable> {
        self.data().cos().map(Variable::wrap)
    }

    pub fn sign(&self) -> Result<Variable> {
        self.data().sign().map(Variable::wrap)
    }

    pub fn floor(&self) -> Result<Variable> {
        self.data().floor().map(Variable::wrap)
    }

    pub fn ceil(&self) -> Result<Variable> {
        self.data().ceil().map(Variable::wrap)
    }

    pub fn round(&self) -> Result<Variable> {
        self.data().round().map(Variable::wrap)
    }

    pub fn reciprocal(&self) -> Result<Variable> {
        self.data().reciprocal().map(Variable::wrap)
    }

    pub fn clamp(&self, min: f64, max: f64) -> Result<Variable> {
        self.data().clamp(min, max).map(Variable::wrap)
    }

    pub fn clamp_min(&self, min: f64) -> Result<Variable> {
        self.data().clamp_min(min).map(Variable::wrap)
    }

    pub fn clamp_max(&self, max: f64) -> Result<Variable> {
        self.data().clamp_max(max).map(Variable::wrap)
    }

    pub fn log1p(&self) -> Result<Variable> {
        self.data().log1p().map(Variable::wrap)
    }

    pub fn expm1(&self) -> Result<Variable> {
        self.data().expm1().map(Variable::wrap)
    }

    pub fn log2(&self) -> Result<Variable> {
        self.data().log2().map(Variable::wrap)
    }

    pub fn log10(&self) -> Result<Variable> {
        self.data().log10().map(Variable::wrap)
    }

    pub fn atan2(&self, other: &Variable) -> Result<Variable> {
        self.data().atan2(&other.data()).map(Variable::wrap)
    }

    pub fn maximum(&self, other: &Variable) -> Result<Variable> {
        self.data().maximum(&other.data()).map(Variable::wrap)
    }

    pub fn minimum(&self, other: &Variable) -> Result<Variable> {
        self.data().minimum(&other.data()).map(Variable::wrap)
    }

    /// Note: `mask` is a raw `Tensor`, not a `Variable`.
    pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Variable> {
        self.data().masked_fill(mask, value).map(Variable::wrap)
    }

    pub fn normalize(&self, p: f64, dim: i32) -> Result<Variable> {
        self.data().normalize(p, dim).map(Variable::wrap)
    }

    pub fn cosine_similarity(&self, other: &Variable, dim: i64, eps: f64) -> Result<Variable> {
        self.data()
            .cosine_similarity(&other.data(), dim, eps)
            .map(Variable::wrap)
    }

    // ---- shape and indexing ops ------------------------------------------

    pub fn reshape(&self, shape: &[i64]) -> Result<Variable> {
        self.data().reshape(shape).map(Variable::wrap)
    }

    pub fn transpose(&self, dim0: i32, dim1: i32) -> Result<Variable> {
        self.data().transpose(dim0, dim1).map(Variable::wrap)
    }

    pub fn permute(&self, dims: &[i64]) -> Result<Variable> {
        self.data().permute(dims).map(Variable::wrap)
    }

    pub fn squeeze(&self, dim: i32) -> Result<Variable> {
        self.data().squeeze(dim).map(Variable::wrap)
    }

    pub fn unsqueeze(&self, dim: i32) -> Result<Variable> {
        self.data().unsqueeze(dim).map(Variable::wrap)
    }

    pub fn unsqueeze_many(&self, dims: &[i32]) -> Result<Variable> {
        self.data().unsqueeze_many(dims).map(Variable::wrap)
    }

    pub fn flatten(&self, start_dim: i32, end_dim: i32) -> Result<Variable> {
        self.data().flatten(start_dim, end_dim).map(Variable::wrap)
    }

    pub fn expand(&self, shape: &[i64]) -> Result<Variable> {
        self.data().expand(shape).map(Variable::wrap)
    }

    pub fn narrow(&self, dim: i32, start: i64, length: i64) -> Result<Variable> {
        self.data().narrow(dim, start, length).map(Variable::wrap)
    }

    pub fn select(&self, dim: i32, index: i64) -> Result<Variable> {
        self.data().select(dim, index).map(Variable::wrap)
    }

    /// Note: `index` is a raw `Tensor`, not a `Variable`.
    pub fn index_select(&self, dim: i32, index: &Tensor) -> Result<Variable> {
        self.data().index_select(dim, index).map(Variable::wrap)
    }

    /// Note: `index` is a raw `Tensor`, not a `Variable`.
    pub fn gather(&self, dim: i32, index: &Tensor) -> Result<Variable> {
        self.data().gather(dim, index).map(Variable::wrap)
    }

    // ---- combining and splitting -----------------------------------------

    pub fn cat(&self, other: &Variable, dim: i32) -> Result<Variable> {
        self.data().cat(&other.data(), dim).map(Variable::wrap)
    }

    /// Concatenates all `vars` along `dim` (associated-function form of `cat`).
    pub fn cat_many(vars: &[&Variable], dim: i32) -> Result<Variable> {
        // Materialize owned tensors first so the borrows below stay valid.
        let owned: Vec<Tensor> = vars.iter().map(|v| v.data()).collect();
        let borrowed: Vec<&Tensor> = owned.iter().collect();
        Tensor::cat_many(&borrowed, dim).map(Variable::wrap)
    }

    /// Stacks all `vars` along a new dimension `dim`.
    pub fn stack(vars: &[Variable], dim: i32) -> Result<Variable> {
        // Materialize owned tensors first so the borrows below stay valid.
        let owned: Vec<Tensor> = vars.iter().map(|v| v.data()).collect();
        let borrowed: Vec<&Tensor> = owned.iter().collect();
        Tensor::stack(&borrowed, dim).map(Variable::wrap)
    }

    /// Splits into `chunks` pieces along `dim`; each piece is rewrapped.
    pub fn chunk(&self, chunks: i32, dim: i32) -> Result<Vec<Variable>> {
        let pieces = self.data().chunk(chunks, dim)?;
        Ok(pieces.into_iter().map(Variable::wrap).collect())
    }

    pub fn repeat(&self, repeats: &[i64]) -> Result<Variable> {
        self.data().repeat(repeats).map(Variable::wrap)
    }

    pub fn pad(&self, padding: &[i64], value: f64) -> Result<Variable> {
        self.data().pad(padding, value).map(Variable::wrap)
    }

    // ---- ordering --------------------------------------------------------

    /// Top-k values (wrapped) plus their raw index tensor.
    pub fn topk(&self, k: i64, dim: i32, largest: bool, sorted: bool) -> Result<(Variable, Tensor)> {
        self.data()
            .topk(k, dim, largest, sorted)
            .map(|(values, indices)| (Variable::wrap(values), indices))
    }

    /// Sorted values (wrapped) plus their raw index tensor.
    pub fn sort(&self, dim: i32, descending: bool) -> Result<(Variable, Tensor)> {
        self.data()
            .sort(dim, descending)
            .map(|(values, indices)| (Variable::wrap(values), indices))
    }
}
/// Fully-connected layer: forwards to `Tensor::linear` with an optional bias,
/// rewrapping the result as a `Variable`.
pub fn linear(
    input: &Variable,
    weight: &Variable,
    bias: Option<&Variable>,
) -> Result<Variable> {
    // Keep the optional bias tensor alive while we borrow it below.
    let b = bias.map(|v| v.data());
    input
        .data()
        .linear(&weight.data(), b.as_ref())
        .map(Variable::wrap)
}
/// Single GRU step: forwards to `Tensor::gru_cell` with hidden state `hx`
/// and the input/hidden weight and bias variables.
#[allow(clippy::too_many_arguments)]
pub fn gru_cell(
    input: &Variable,
    hx: &Variable,
    w_ih: &Variable,
    w_hh: &Variable,
    b_ih: &Variable,
    b_hh: &Variable,
) -> Result<Variable> {
    input
        .data()
        .gru_cell(
            &hx.data(),
            &w_ih.data(),
            &w_hh.data(),
            &b_ih.data(),
            &b_hh.data(),
        )
        .map(Variable::wrap)
}
/// Single LSTM step: forwards to `Tensor::lstm_cell` and rewraps the
/// returned (hidden, cell) state pair as `Variable`s.
#[allow(clippy::too_many_arguments)]
pub fn lstm_cell(
    input: &Variable,
    hx: &Variable,
    cx: &Variable,
    w_ih: &Variable,
    w_hh: &Variable,
    b_ih: &Variable,
    b_hh: &Variable,
) -> Result<(Variable, Variable)> {
    let (h_next, c_next) = input.data().lstm_cell(
        &hx.data(),
        &cx.data(),
        &w_ih.data(),
        &w_hh.data(),
        &b_ih.data(),
        &b_hh.data(),
    )?;
    Ok((Variable::wrap(h_next), Variable::wrap(c_next)))
}
/// Layer normalization via `Tensor::native_layer_norm`; the mean and rstd
/// statistics it also returns are discarded here.
pub fn layer_norm(
    input: &Variable,
    weight: &Variable,
    bias: &Variable,
    normalized_size: i64,
    eps: f64,
) -> Result<Variable> {
    input
        .data()
        .native_layer_norm(&weight.data(), &bias.data(), normalized_size, eps)
        .map(|(output, _mean, _rstd)| Variable::wrap(output))
}
/// 2-D convolution: forwards to `Tensor::conv2d` with optional bias.
pub fn conv2d(
    input: &Variable,
    weight: &Variable,
    bias: Option<&Variable>,
    stride: [i64; 2],
    padding: [i64; 2],
    dilation: [i64; 2],
    groups: i64,
) -> Result<Variable> {
    // Keep the optional bias tensor alive while we borrow it below.
    let b = bias.map(|v| v.data());
    input
        .data()
        .conv2d(&weight.data(), b.as_ref(), stride, padding, dilation, groups)
        .map(Variable::wrap)
}
/// 2-D transposed convolution: forwards to `Tensor::conv_transpose2d`.
#[allow(clippy::too_many_arguments)]
pub fn conv_transpose2d(
    input: &Variable,
    weight: &Variable,
    bias: Option<&Variable>,
    stride: [i64; 2],
    padding: [i64; 2],
    output_padding: [i64; 2],
    dilation: [i64; 2],
    groups: i64,
) -> Result<Variable> {
    // Keep the optional bias tensor alive while we borrow it below.
    let b = bias.map(|v| v.data());
    input
        .data()
        .conv_transpose2d(
            &weight.data(),
            b.as_ref(),
            stride,
            padding,
            output_padding,
            dilation,
            groups,
        )
        .map(Variable::wrap)
}
/// 1-D convolution: forwards to `Tensor::conv1d` with optional bias.
pub fn conv1d(
    input: &Variable,
    weight: &Variable,
    bias: Option<&Variable>,
    stride: i64,
    padding: i64,
    dilation: i64,
    groups: i64,
) -> Result<Variable> {
    // Keep the optional bias tensor alive while we borrow it below.
    let b = bias.map(|v| v.data());
    input
        .data()
        .conv1d(&weight.data(), b.as_ref(), stride, padding, dilation, groups)
        .map(Variable::wrap)
}
/// 1-D transposed convolution: forwards to `Tensor::conv_transpose1d`.
#[allow(clippy::too_many_arguments)]
pub fn conv_transpose1d(
    input: &Variable,
    weight: &Variable,
    bias: Option<&Variable>,
    stride: i64,
    padding: i64,
    output_padding: i64,
    dilation: i64,
    groups: i64,
) -> Result<Variable> {
    // Keep the optional bias tensor alive while we borrow it below.
    let b = bias.map(|v| v.data());
    input
        .data()
        .conv_transpose1d(
            &weight.data(),
            b.as_ref(),
            stride,
            padding,
            output_padding,
            dilation,
            groups,
        )
        .map(Variable::wrap)
}
/// Group normalization: forwards to `Tensor::group_norm`. Here weight and
/// bias are required `Variable`s, always passed through as `Some`.
pub fn group_norm(
    input: &Variable,
    num_groups: i64,
    weight: &Variable,
    bias: &Variable,
    eps: f64,
) -> Result<Variable> {
    let gamma = weight.data();
    let beta = bias.data();
    input
        .data()
        .group_norm(num_groups, Some(&gamma), Some(&beta), eps)
        .map(Variable::wrap)
}
/// 2-D max pooling: forwards to `Tensor::max_pool2d`.
pub fn max_pool2d(
    input: &Variable,
    kernel_size: [i64; 2],
    stride: [i64; 2],
    padding: [i64; 2],
    dilation: [i64; 2],
    ceil_mode: bool,
) -> Result<Variable> {
    input
        .data()
        .max_pool2d(kernel_size, stride, padding, dilation, ceil_mode)
        .map(Variable::wrap)
}
/// 2-D average pooling: forwards to `Tensor::avg_pool2d`.
pub fn avg_pool2d(
    input: &Variable,
    kernel_size: [i64; 2],
    stride: [i64; 2],
    padding: [i64; 2],
    ceil_mode: bool,
    count_include_pad: bool,
) -> Result<Variable> {
    input
        .data()
        .avg_pool2d(kernel_size, stride, padding, ceil_mode, count_include_pad)
        .map(Variable::wrap)
}
/// Adaptive 2-D average pooling to `output_size`: forwards to
/// `Tensor::adaptive_avg_pool2d`.
pub fn adaptive_avg_pool2d(
    input: &Variable,
    output_size: [i64; 2],
) -> Result<Variable> {
    input
        .data()
        .adaptive_avg_pool2d(output_size)
        .map(Variable::wrap)
}
/// Image-to-column unfolding: forwards to `Tensor::im2col`.
pub fn im2col(
    input: &Variable,
    kernel_size: [i64; 2],
    dilation: [i64; 2],
    padding: [i64; 2],
    stride: [i64; 2],
) -> Result<Variable> {
    input
        .data()
        .im2col(kernel_size, dilation, padding, stride)
        .map(Variable::wrap)
}
/// Column-to-image folding (inverse of `im2col`): forwards to
/// `Tensor::col2im`.
pub fn col2im(
    input: &Variable,
    output_size: [i64; 2],
    kernel_size: [i64; 2],
    dilation: [i64; 2],
    padding: [i64; 2],
    stride: [i64; 2],
) -> Result<Variable> {
    input
        .data()
        .col2im(output_size, kernel_size, dilation, padding, stride)
        .map(Variable::wrap)
}
/// 3-D convolution: forwards to `Tensor::conv3d` with optional bias.
#[allow(clippy::too_many_arguments)]
pub fn conv3d(
    input: &Variable,
    weight: &Variable,
    bias: Option<&Variable>,
    stride: [i64; 3],
    padding: [i64; 3],
    dilation: [i64; 3],
    groups: i64,
) -> Result<Variable> {
    // Keep the optional bias tensor alive while we borrow it below.
    let b = bias.map(|v| v.data());
    input
        .data()
        .conv3d(&weight.data(), b.as_ref(), stride, padding, dilation, groups)
        .map(Variable::wrap)
}
/// 3-D transposed convolution: forwards to `Tensor::conv_transpose3d`.
#[allow(clippy::too_many_arguments)]
pub fn conv_transpose3d(
    input: &Variable,
    weight: &Variable,
    bias: Option<&Variable>,
    stride: [i64; 3],
    padding: [i64; 3],
    output_padding: [i64; 3],
    dilation: [i64; 3],
    groups: i64,
) -> Result<Variable> {
    // Keep the optional bias tensor alive while we borrow it below.
    let b = bias.map(|v| v.data());
    input
        .data()
        .conv_transpose3d(
            &weight.data(),
            b.as_ref(),
            stride,
            padding,
            output_padding,
            dilation,
            groups,
        )
        .map(Variable::wrap)
}
/// 1-D max pooling: forwards to `Tensor::max_pool1d`.
pub fn max_pool1d(
    input: &Variable,
    kernel_size: i64,
    stride: i64,
    padding: i64,
    dilation: i64,
    ceil_mode: bool,
) -> Result<Variable> {
    input
        .data()
        .max_pool1d(kernel_size, stride, padding, dilation, ceil_mode)
        .map(Variable::wrap)
}
/// 1-D average pooling: forwards to `Tensor::avg_pool1d`.
pub fn avg_pool1d(
    input: &Variable,
    kernel_size: i64,
    stride: i64,
    padding: i64,
    ceil_mode: bool,
    count_include_pad: bool,
) -> Result<Variable> {
    input
        .data()
        .avg_pool1d(kernel_size, stride, padding, ceil_mode, count_include_pad)
        .map(Variable::wrap)
}
/// Adaptive 2-D max pooling to `output_size`: forwards to
/// `Tensor::adaptive_max_pool2d`.
pub fn adaptive_max_pool2d(
    input: &Variable,
    output_size: [i64; 2],
) -> Result<Variable> {
    input
        .data()
        .adaptive_max_pool2d(output_size)
        .map(Variable::wrap)
}
/// Instance normalization: forwards to `Tensor::instance_norm`. Weight/bias
/// are optional `Variable`s; the running statistics are raw `Tensor`s passed
/// straight through.
#[allow(clippy::too_many_arguments)]
pub fn instance_norm(
    input: &Variable,
    weight: Option<&Variable>,
    bias: Option<&Variable>,
    running_mean: Option<&Tensor>,
    running_var: Option<&Tensor>,
    use_input_stats: bool,
    momentum: f64,
    eps: f64,
) -> Result<Variable> {
    // Keep the optional affine tensors alive while we borrow them below.
    let gamma = weight.map(|v| v.data());
    let beta = bias.map(|v| v.data());
    input
        .data()
        .instance_norm(
            gamma.as_ref(),
            beta.as_ref(),
            running_mean,
            running_var,
            use_input_stats,
            momentum,
            eps,
        )
        .map(Variable::wrap)
}
/// Pixel shuffle (sub-pixel rearrangement): forwards to
/// `Tensor::pixel_shuffle`.
pub fn pixel_shuffle(input: &Variable, upscale_factor: i64) -> Result<Variable> {
    input
        .data()
        .pixel_shuffle(upscale_factor)
        .map(Variable::wrap)
}
/// Inverse of `pixel_shuffle`: forwards to `Tensor::pixel_unshuffle`.
pub fn pixel_unshuffle(input: &Variable, downscale_factor: i64) -> Result<Variable> {
    input
        .data()
        .pixel_unshuffle(downscale_factor)
        .map(Variable::wrap)
}
/// Bilinear transform of two inputs: forwards to the associated
/// `Tensor::bilinear` with an optional bias.
pub fn bilinear(
    input1: &Variable,
    input2: &Variable,
    weight: &Variable,
    bias: Option<&Variable>,
) -> Result<Variable> {
    // Keep the optional bias tensor alive while we borrow it below.
    let b = bias.map(|v| v.data());
    Tensor::bilinear(&input1.data(), &input2.data(), &weight.data(), b.as_ref())
        .map(Variable::wrap)
}
/// Grid sampling of `input` at `grid` locations: forwards to
/// `Tensor::grid_sample`. The integer `mode`/`padding_mode` codes are
/// interpreted by that tensor op.
pub fn grid_sample(
    input: &Variable,
    grid: &Variable,
    mode: i32,
    padding_mode: i32,
    align_corners: bool,
) -> Result<Variable> {
    input
        .data()
        .grid_sample(&grid.data(), mode, padding_mode, align_corners)
        .map(Variable::wrap)
}
/// Attention over (query, key, value) variables: forwards to the associated
/// `Tensor::scaled_dot_product_attention`. `attn_mask` is a raw `Tensor`.
pub fn scaled_dot_product_attention(
    query: &Variable,
    key: &Variable,
    value: &Variable,
    attn_mask: Option<&Tensor>,
    dropout_p: f64,
    is_causal: bool,
    scale: Option<f64>,
) -> Result<Variable> {
    Tensor::scaled_dot_product_attention(
        &query.data(),
        &key.data(),
        &value.data(),
        attn_mask,
        dropout_p,
        is_causal,
        scale,
    )
    .map(Variable::wrap)
}
/// Embedding lookup of `indices` into the `weight` table: forwards to the
/// associated `Tensor::embedding`.
pub fn embedding(
    weight: &Variable,
    indices: &Tensor,
    padding_idx: i64,
) -> Result<Variable> {
    Tensor::embedding(&weight.data(), indices, padding_idx).map(Variable::wrap)
}
/// Bagged embedding lookup (`indices` segmented by `offsets`, reduced per
/// the integer `mode` code): forwards to the associated
/// `Tensor::embedding_bag`.
pub fn embedding_bag(
    weight: &Variable,
    indices: &Tensor,
    offsets: &Tensor,
    mode: i64,
) -> Result<Variable> {
    Tensor::embedding_bag(&weight.data(), indices, offsets, mode).map(Variable::wrap)
}