musicgpt 0.3.0

Generate music samples from a natural language prompt locally on your own computer
use crate::logits::Logits;

/// Wrapper around [`ort::SessionOutputs`] that lets individual named outputs
/// (the `logits` output and the `present.{i}.{decoder,encoder}.{key,value}`
/// outputs) be moved out one at a time.
///
/// Each named output can be taken at most once: taking it removes it from the
/// wrapped outputs and hands ownership to the caller.
pub struct MusicGenOutputs<'s> {
    outputs: ort::SessionOutputs<'s>,
}

impl<'s> MusicGenOutputs<'s> {
    /// Wraps the outputs returned by an `ort` session run.
    pub fn new(outputs: ort::SessionOutputs<'s>) -> Self {
        Self { outputs }
    }

    /// Removes the output called `name` from the wrapped outputs and returns it.
    ///
    /// All public `take_*` accessors go through this helper so that a missing
    /// output always produces the same descriptive panic message.
    ///
    /// # Panics
    ///
    /// Panics if no output named `name` is present, e.g. because it was already
    /// taken by an earlier call.
    fn take_output(&mut self, name: &str) -> ort::DynValue {
        self.outputs
            .remove(name)
            .unwrap_or_else(|| panic!("{name} was already taken from outputs"))
    }

    /// Takes the `logits` output and converts it with [`Logits::from_3d_dyn_value`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Logits::from_3d_dyn_value`] reports for the
    /// taken value.
    ///
    /// # Panics
    ///
    /// Panics if the `logits` output is not present (e.g. it was already taken).
    pub fn take_logits(&mut self) -> ort::Result<Logits> {
        let logits = self.take_output("logits");
        Logits::from_3d_dyn_value(&logits)
    }

    /// Takes the `present.{i}.decoder.key` output.
    ///
    /// # Panics
    ///
    /// Panics if that output is not present (e.g. it was already taken).
    pub fn take_present_decoder_key(&mut self, i: usize) -> ort::DynValue {
        self.take_output(&format!("present.{i}.decoder.key"))
    }

    /// Takes the `present.{i}.decoder.value` output.
    ///
    /// # Panics
    ///
    /// Panics if that output is not present (e.g. it was already taken).
    pub fn take_present_decoder_value(&mut self, i: usize) -> ort::DynValue {
        self.take_output(&format!("present.{i}.decoder.value"))
    }

    /// Takes the `present.{i}.encoder.key` output.
    ///
    /// # Panics
    ///
    /// Panics if that output is not present (e.g. it was already taken).
    pub fn take_present_encoder_key(&mut self, i: usize) -> ort::DynValue {
        self.take_output(&format!("present.{i}.encoder.key"))
    }

    /// Takes the `present.{i}.encoder.value` output.
    ///
    /// # Panics
    ///
    /// Panics if that output is not present (e.g. it was already taken).
    pub fn take_present_encoder_value(&mut self, i: usize) -> ort::DynValue {
        self.take_output(&format!("present.{i}.encoder.value"))
    }
}