musicgpt 0.3.26

Generate music based on natural language prompts using LLMs running locally
use std::collections::VecDeque;
use std::time::Duration;

use async_trait::async_trait;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};

use crate::backend::audio_generation_backend::{
    AudioGenerationRequest, BackendOutboundMsg, JobProcessor,
};
use crate::backend::audio_generation_fanout::{
    AudioGenerationError, AudioGenerationProgress, AudioGenerationResult, AudioGenerationStart,
    GenerationMessage,
};
use crate::backend::music_gpt_chat::{Chat, ChatEntry};
use crate::backend::music_gpt_ws_handler::{Info, OutboundMsg};
use crate::storage::AppFs;

/// Variant-unwrapping helpers for [`OutboundMsg`].
///
/// Each method consumes the message and returns the payload of one specific
/// variant, panicking with a message that names the expected variant and
/// debug-prints the actual message when it is anything else.
impl OutboundMsg {
    /// Returns the payload of `OutboundMsg::Info`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn info(self) -> Info {
        match self {
            OutboundMsg::Info(p) => p,
            other => panic!("msg was not OutboundMsg::Info, it was {other:?}"),
        }
    }

    /// Returns the payload of `OutboundMsg::Chats`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn chats(self) -> Vec<Chat> {
        match self {
            OutboundMsg::Chats(p) => p,
            other => panic!("msg was not OutboundMsg::Chats, it was {other:?}"),
        }
    }

    /// Returns the payload of `OutboundMsg::Generation(GenerationMessage::Start(_))`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn start(self) -> AudioGenerationStart {
        match self {
            OutboundMsg::Generation(GenerationMessage::Start(p)) => p,
            other => panic!("msg was not GenerationMessage::Start, it was {other:?}"),
        }
    }

    /// Returns the payload of `OutboundMsg::Generation(GenerationMessage::Progress(_))`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn progress(self) -> AudioGenerationProgress {
        match self {
            OutboundMsg::Generation(GenerationMessage::Progress(p)) => p,
            other => panic!("msg was not GenerationMessage::Progress, it was {other:?}"),
        }
    }

    /// Returns the payload of `OutboundMsg::Generation(GenerationMessage::Result(_))`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn result(self) -> AudioGenerationResult {
        match self {
            OutboundMsg::Generation(GenerationMessage::Result(p)) => p,
            other => panic!("msg was not GenerationMessage::Result, it was {other:?}"),
        }
    }

    /// Returns the payload of `OutboundMsg::Generation(GenerationMessage::Error(_))`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn error(self) -> AudioGenerationError {
        match self {
            OutboundMsg::Generation(GenerationMessage::Error(p)) => p,
            other => panic!("msg was not GenerationMessage::Error, it was {other:?}"),
        }
    }

    /// Returns the payload of `OutboundMsg::Chat`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn chat(self) -> (Chat, Vec<ChatEntry>) {
        match self {
            OutboundMsg::Chat(p) => p,
            other => panic!("msg was not OutboundMsg::Chat, it was {other:?}"),
        }
    }
}

/// Variant-unwrapping helpers for [`BackendOutboundMsg`].
///
/// Each method consumes the message and returns the payload of one specific
/// variant, panicking with a message that names the expected variant and
/// debug-prints the actual message when it is anything else.
impl BackendOutboundMsg {
    /// Returns the payload of `BackendOutboundMsg::Start`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn unwrap_start(self) -> AudioGenerationRequest {
        match self {
            BackendOutboundMsg::Start(p) => p,
            other => panic!("msg was not Start, it was {other:?}"),
        }
    }

    /// Returns the payload of `BackendOutboundMsg::Progress`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn unwrap_progress(self) -> (String, f32) {
        match self {
            BackendOutboundMsg::Progress(p) => p,
            other => panic!("msg was not Progress, it was {other:?}"),
        }
    }

    /// Returns the payload of `BackendOutboundMsg::Response`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn unwrap_response(self) -> (String, VecDeque<f32>) {
        match self {
            BackendOutboundMsg::Response(p) => p,
            other => panic!("msg was not Response, it was {other:?}"),
        }
    }

    /// Returns the payload of `BackendOutboundMsg::Failure`.
    ///
    /// # Panics
    ///
    /// Panics if the message is any other variant.
    pub(crate) fn unwrap_err(self) -> (String, String) {
        match self {
            BackendOutboundMsg::Failure(p) => p,
            other => panic!("msg was not Failure, it was {other:?}"),
        }
    }
}

/// A stand-in job processor that fabricates output instead of running a model.
///
/// The derived `Default` yields a zero `wait_scale`, i.e. no artificial delay.
#[derive(Default)]
pub struct DummyJobProcessor {
    /// How long the processor sleeps for each generated step.
    wait_scale: Duration,
}

impl DummyJobProcessor {
    /// Builds a processor that pauses for `wait_scale` on every step.
    pub fn new(wait_scale: Duration) -> Self {
        DummyJobProcessor { wait_scale }
    }
}

#[async_trait]
impl JobProcessor for DummyJobProcessor {
    fn process(
        &self,
        prompt: &str,
        secs: usize,
        on_progress: Box<dyn Fn(f32, f32) -> bool + Sync + Send + 'static>,
    ) -> ort::Result<VecDeque<f32>> {
        let mut result = VecDeque::new();
        for i in 0..secs {
            if prompt == format!("fail at {i}") {
                return Err(ort::Error::new(format!("Failed at {i}")));
            }
            std::thread::sleep(self.wait_scale);
            result.push_back(i as f32);
            let should_exit = on_progress(result.len() as f32, secs as f32);
            if should_exit {
                return Err(ort::Error::new("Aborted"));
            }
        }

        Ok(result)
    }
}

/// Returns a random string of 7 characters sampled from `Alphanumeric`,
/// using the thread-local RNG.
pub fn rand_string() -> String {
    const LEN: usize = 7;

    let rng = thread_rng();
    let picked = rng.sample_iter(&Alphanumeric).take(LEN).map(char::from);
    picked.collect()
}

impl AppFs {
    /// Creates an `AppFs` rooted at a fresh, randomly named directory path
    /// under `<system temp dir>/musicgpt-tests/`.
    ///
    /// The base comes from [`std::env::temp_dir`] rather than a hard-coded
    /// `/tmp`, so the path is also valid on platforms without `/tmp` and
    /// honours a user-configured temporary directory.
    pub fn new_tmp() -> Self {
        let root = std::env::temp_dir()
            .join("musicgpt-tests")
            .join(rand_string());
        // Passed on as a `String`, matching the argument type used before;
        // a non-UTF-8 temp dir would be converted lossily.
        Self::new(root.to_string_lossy().into_owned())
    }
}