musicgpt 0.2.1

Generate music samples from a natural language prompt, locally on your own computer
use std::collections::HashMap;

/// A set of named `ort::DynValue` inputs that can be handed to an ONNX Runtime
/// session through [`MusicGenInputs::ort`].
///
/// Each setter method stores its value under a fixed input name, replacing any
/// value previously stored under that name.
///
/// `Default` produces the same state as [`MusicGenInputs::new`]: no inputs and
/// `use_cache_branch == false`.
#[derive(Default)]
pub struct MusicGenInputs {
    /// Input name -> value, as populated by the setter methods.
    inputs: HashMap<String, ort::DynValue>,
    /// Mirrors the last value set through the `use_cache_branch` method
    /// (initially `false`).
    pub use_cache_branch: bool,
}

impl MusicGenInputs {
    /// Creates an empty input set with `use_cache_branch` set to `false`.
    pub fn new() -> Self {
        Self {
            inputs: HashMap::new(),
            use_cache_branch: false,
        }
    }

    /// Converts `v` into an `ort::DynValue` and stores it under `name`,
    /// replacing any value already stored under that name.
    ///
    /// # Errors
    ///
    /// Returns the conversion error unchanged; in that case the stored inputs
    /// are left untouched.
    fn set_input<T, E>(&mut self, name: impl Into<String>, v: T) -> Result<(), E>
    where
        ort::DynValue: TryFrom<T, Error = E>,
    {
        let name = name.into();
        let value = ort::DynValue::try_from(v)?;
        self.inputs.insert(name, value);
        Ok(())
    }

    /// Sets the `encoder_attention_mask` input.
    ///
    /// # Errors
    ///
    /// Fails if `v` cannot be converted into an `ort::DynValue`.
    pub fn encoder_attention_mask<T, E>(&mut self, v: T) -> Result<(), E>
    where
        ort::DynValue: TryFrom<T, Error = E>,
    {
        self.set_input("encoder_attention_mask", v)
    }

    /// Sets the `input_ids` input.
    ///
    /// # Errors
    ///
    /// Fails if `v` cannot be converted into an `ort::DynValue`.
    pub fn input_ids<T, E>(&mut self, v: T) -> Result<(), E>
    where
        ort::DynValue: TryFrom<T, Error = E>,
    {
        self.set_input("input_ids", v)
    }

    /// Sets the `encoder_hidden_states` input.
    ///
    /// # Errors
    ///
    /// Fails if `v` cannot be converted into an `ort::DynValue`.
    pub fn encoder_hidden_states<T, E>(&mut self, v: T) -> Result<(), E>
    where
        ort::DynValue: TryFrom<T, Error = E>,
    {
        self.set_input("encoder_hidden_states", v)
    }

    /// Removes the `encoder_hidden_states` input, if it is present.
    pub fn remove_encoder_hidden_states(&mut self) {
        self.inputs.remove("encoder_hidden_states");
    }

    /// Sets the `past_key_values.{i}.decoder.key` input, where `i` is
    /// interpolated into the input name (presumably a layer index — confirm
    /// against the model's input names).
    ///
    /// # Errors
    ///
    /// Fails if `v` cannot be converted into an `ort::DynValue`.
    pub fn past_key_value_decoder_key<T, E>(&mut self, i: usize, v: T) -> Result<(), E>
    where
        ort::DynValue: TryFrom<T, Error = E>,
    {
        self.set_input(format!("past_key_values.{i}.decoder.key"), v)
    }

    /// Sets the `past_key_values.{i}.decoder.value` input.
    ///
    /// # Errors
    ///
    /// Fails if `v` cannot be converted into an `ort::DynValue`.
    pub fn past_key_value_decoder_value<T, E>(&mut self, i: usize, v: T) -> Result<(), E>
    where
        ort::DynValue: TryFrom<T, Error = E>,
    {
        self.set_input(format!("past_key_values.{i}.decoder.value"), v)
    }

    /// Sets the `past_key_values.{i}.encoder.key` input.
    ///
    /// # Errors
    ///
    /// Fails if `v` cannot be converted into an `ort::DynValue`.
    pub fn past_key_value_encoder_key<T, E>(&mut self, i: usize, v: T) -> Result<(), E>
    where
        ort::DynValue: TryFrom<T, Error = E>,
    {
        self.set_input(format!("past_key_values.{i}.encoder.key"), v)
    }

    /// Sets the `past_key_values.{i}.encoder.value` input.
    ///
    /// # Errors
    ///
    /// Fails if `v` cannot be converted into an `ort::DynValue`.
    pub fn past_key_value_encoder_value<T, E>(&mut self, i: usize, v: T) -> Result<(), E>
    where
        ort::DynValue: TryFrom<T, Error = E>,
    {
        self.set_input(format!("past_key_values.{i}.encoder.value"), v)
    }

    /// Records `value` in the public `use_cache_branch` field and stores it as
    /// the `use_cache_branch` input, encoded as a one-element bool tensor of
    /// shape `[1]`.
    ///
    /// # Panics
    ///
    /// Panics if `ort` fails to build that tensor.
    pub fn use_cache_branch(&mut self, value: bool) {
        self.use_cache_branch = value;
        let flag = ort::Tensor::from_array(([1], vec![value]))
            .expect("building a [1]-shaped bool tensor from one element should not fail")
            .into_dyn();
        self.inputs.insert("use_cache_branch".to_string(), flag);
    }

    /// Builds an `ort::SessionInputs::ValueMap` containing a view of every
    /// stored input, keyed by its name.
    ///
    /// Names and values are borrowed from `self` rather than copied, so the
    /// result cannot outlive this `MusicGenInputs`.
    pub fn ort(&self) -> ort::SessionInputs {
        ort::SessionInputs::ValueMap(
            self.inputs
                .iter()
                // Borrow the key instead of cloning it: the returned value is
                // already tied to `&self` through the value views.
                .map(|(name, value)| (name.as_str().into(), value.view().into()))
                .collect::<Vec<_>>(),
        )
    }
}