musicgpt 0.3.0

Generate music samples from natural-language prompts locally on your own computer
use std::fmt::Debug;
use ndarray::Array;
use num_traits::{One, Zero};

/// Creates an `ort` tensor with the given `shape`, every element set to `T::zero()`.
///
/// The backing array is built with `ndarray::Array::zeros` and then handed to
/// `ort::Value::from_array`.
///
/// # Panics
///
/// Panics with "Could not build zeros tensor" if `ort` fails to wrap the array.
pub fn zeros_tensor<T>(shape: &[usize]) -> ort::Tensor<T>
where
    T: ort::IntoTensorElementType + Debug + Clone + Zero + 'static,
{
    let all_zero = Array::<T, _>::zeros(shape);
    let tensor = ort::Value::from_array(all_zero);
    tensor.expect("Could not build zeros tensor")
}

/// Returns a new tensor whose first dimension is twice as large as `tensor`'s.
///
/// The original elements come first, followed by an equally sized block of `T::zero()`.
/// For example, a tensor of shape `[b, ...]` becomes shape `[2 * b, ...]`.
///
/// NOTE(review): appending the zero block at the end of the flat buffer only matches
/// "concatenate along axis 0" if the extracted data is contiguous row-major. That is
/// presumably the case for ONNX Runtime tensors; confirm.
///
/// # Errors
///
/// Propagates any error from extracting the raw tensor data or from building the
/// output tensor.
///
/// # Panics
///
/// Panics if `tensor` has rank 0 (an empty shape), because there is no first dimension
/// to double.
pub fn dupe_zeros_along_first_dim<T: ort::IntoTensorElementType + Debug + Zero + Clone + 'static>(
    tensor: ort::Tensor<T>,
) -> ort::Result<ort::Tensor<T>> {
    let (mut shape, data) = tensor.try_extract_raw_tensor()?;
    shape[0] *= 2;

    // One allocation sized for the final length. The previous
    // `[data.to_vec(), vec![..]].concat()` allocated three buffers and copied twice.
    let doubled_len = data.len() * 2;
    let mut doubled = Vec::with_capacity(doubled_len);
    doubled.extend_from_slice(data);
    doubled.resize(doubled_len, T::zero());

    ort::Tensor::from_array((shape, doubled))
}

/// Creates an `ort` tensor with the given `shape`, every element set to `T::one()`.
///
/// The backing array is built with `ndarray::Array::ones` and then wrapped with
/// `ort::Value::from_array`.
///
/// # Panics
///
/// Panics if `ort` fails to wrap the array.
pub fn ones_tensor<T: ort::IntoTensorElementType + Debug + Clone + One + 'static>(
    shape: &[usize],
) -> ort::Tensor<T> {
    // The message previously said "zeros" (copy-paste from `zeros_tensor`), which
    // pointed debugging at the wrong helper.
    ort::Value::from_array(Array::<T, _>::ones(shape)).expect("Could not build ones tensor")
}