use std::collections::VecDeque;
use std::time::Duration;
use async_trait::async_trait;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use crate::backend::audio_generation_backend::{
AudioGenerationRequest, BackendOutboundMsg, JobProcessor,
};
use crate::backend::audio_generation_fanout::{
AudioGenerationError, AudioGenerationProgress, AudioGenerationResult, AudioGenerationStart,
GenerationMessage,
};
use crate::backend::music_gpt_chat::{Chat, ChatEntry};
use crate::backend::music_gpt_ws_handler::{Info, OutboundMsg};
use crate::storage::AppFs;
/// Test-only accessors that unwrap an [`OutboundMsg`] into the payload of one specific variant.
///
/// Each accessor consumes the message. It panics, with the actual message in the panic text,
/// when the message is a different variant. `#[track_caller]` makes the panic point at the
/// test that made the wrong assumption rather than at this helper.
impl OutboundMsg {
    /// Returns the payload of an `OutboundMsg::Info` message.
    ///
    /// # Panics
    /// Panics if `self` is not `OutboundMsg::Info`.
    #[track_caller]
    pub(crate) fn info(self) -> Info {
        match self {
            OutboundMsg::Info(p) => p,
            _ => panic!("msg was not OutboundMsg::Info, it was {self:?}"),
        }
    }

    /// Returns the chat list carried by an `OutboundMsg::Chats` message.
    ///
    /// # Panics
    /// Panics if `self` is not `OutboundMsg::Chats`.
    #[track_caller]
    pub(crate) fn chats(self) -> Vec<Chat> {
        match self {
            OutboundMsg::Chats(p) => p,
            _ => panic!("msg was not OutboundMsg::Chats, it was {self:?}"),
        }
    }

    /// Returns the payload of an `OutboundMsg::Generation(GenerationMessage::Start(_))` message.
    ///
    /// # Panics
    /// Panics if `self` is any other message.
    #[track_caller]
    pub(crate) fn start(self) -> AudioGenerationStart {
        match self {
            OutboundMsg::Generation(GenerationMessage::Start(p)) => p,
            _ => panic!("msg was not GenerationMessage::Start, it was {self:?}"),
        }
    }

    /// Returns the payload of an `OutboundMsg::Generation(GenerationMessage::Progress(_))` message.
    ///
    /// # Panics
    /// Panics if `self` is any other message.
    #[track_caller]
    pub(crate) fn progress(self) -> AudioGenerationProgress {
        match self {
            OutboundMsg::Generation(GenerationMessage::Progress(p)) => p,
            _ => panic!("msg was not GenerationMessage::Progress, it was {self:?}"),
        }
    }

    /// Returns the payload of an `OutboundMsg::Generation(GenerationMessage::Result(_))` message.
    ///
    /// # Panics
    /// Panics if `self` is any other message.
    #[track_caller]
    pub(crate) fn result(self) -> AudioGenerationResult {
        match self {
            OutboundMsg::Generation(GenerationMessage::Result(p)) => p,
            _ => panic!("msg was not GenerationMessage::Result, it was {self:?}"),
        }
    }

    /// Returns the payload of an `OutboundMsg::Generation(GenerationMessage::Error(_))` message.
    ///
    /// # Panics
    /// Panics if `self` is any other message.
    #[track_caller]
    pub(crate) fn error(self) -> AudioGenerationError {
        match self {
            OutboundMsg::Generation(GenerationMessage::Error(p)) => p,
            _ => panic!("msg was not GenerationMessage::Error, it was {self:?}"),
        }
    }

    /// Returns the chat and its entries carried by an `OutboundMsg::Chat` message.
    ///
    /// # Panics
    /// Panics if `self` is not `OutboundMsg::Chat`.
    #[track_caller]
    pub(crate) fn chat(self) -> (Chat, Vec<ChatEntry>) {
        match self {
            OutboundMsg::Chat(p) => p,
            _ => panic!("msg was not OutboundMsg::Chat, it was {self:?}"),
        }
    }
}
/// Test-only accessors that unwrap a [`BackendOutboundMsg`] into the payload of one variant.
///
/// Each accessor consumes the message and panics, showing the actual message, when it is a
/// different variant. `#[track_caller]` attributes the panic to the calling test.
impl BackendOutboundMsg {
    /// Returns the request carried by a `BackendOutboundMsg::Start` message.
    ///
    /// # Panics
    /// Panics if `self` is not `BackendOutboundMsg::Start`.
    #[track_caller]
    pub(crate) fn unwrap_start(self) -> AudioGenerationRequest {
        match self {
            BackendOutboundMsg::Start(p) => p,
            _ => panic!("msg was not Start, it was {self:?}"),
        }
    }

    /// Returns the `(String, f32)` payload of a `BackendOutboundMsg::Progress` message.
    ///
    /// # Panics
    /// Panics if `self` is not `BackendOutboundMsg::Progress`.
    #[track_caller]
    pub(crate) fn unwrap_progress(self) -> (String, f32) {
        match self {
            BackendOutboundMsg::Progress(p) => p,
            _ => panic!("msg was not Progress, it was {self:?}"),
        }
    }

    /// Returns the `(String, VecDeque<f32>)` payload of a `BackendOutboundMsg::Response` message.
    ///
    /// # Panics
    /// Panics if `self` is not `BackendOutboundMsg::Response`.
    #[track_caller]
    pub(crate) fn unwrap_response(self) -> (String, VecDeque<f32>) {
        match self {
            BackendOutboundMsg::Response(p) => p,
            _ => panic!("msg was not Response, it was {self:?}"),
        }
    }

    /// Returns the `(String, String)` payload of a `BackendOutboundMsg::Failure` message.
    ///
    /// # Panics
    /// Panics if `self` is not `BackendOutboundMsg::Failure`.
    #[track_caller]
    pub(crate) fn unwrap_err(self) -> (String, String) {
        match self {
            BackendOutboundMsg::Failure(p) => p,
            _ => panic!("msg was not Failure, it was {self:?}"),
        }
    }
}
/// A fake [`JobProcessor`] for tests that produces one dummy sample per requested second.
///
/// The default value sleeps for a zero duration between samples.
#[derive(Default)]
pub struct DummyJobProcessor {
    /// Time slept before producing each sample.
    wait_scale: Duration,
}

impl DummyJobProcessor {
    /// Builds a processor that sleeps `wait_scale` before emitting each sample.
    pub fn new(wait_scale: Duration) -> Self {
        DummyJobProcessor { wait_scale }
    }
}
#[async_trait]
impl JobProcessor for DummyJobProcessor {
fn process(
&self,
prompt: &str,
secs: usize,
on_progress: Box<dyn Fn(f32, f32) -> bool + Sync + Send + 'static>,
) -> ort::Result<VecDeque<f32>> {
let mut result = VecDeque::new();
for i in 0..secs {
if prompt == format!("fail at {i}") {
return Err(ort::Error::new(format!("Failed at {i}")));
}
std::thread::sleep(self.wait_scale);
result.push_back(i as f32);
let should_exit = on_progress(result.len() as f32, secs as f32);
if should_exit {
return Err(ort::Error::new("Aborted"));
}
}
Ok(result)
}
}
/// Returns a 7-character string of random ASCII alphanumerics drawn from the thread-local RNG.
pub fn rand_string() -> String {
    const LEN: usize = 7;
    let mut rng = thread_rng();
    (0..LEN)
        .map(|_| char::from(rng.sample(Alphanumeric)))
        .collect()
}
impl AppFs {
    /// Creates an [`AppFs`] rooted at `<temp dir>/musicgpt-tests/<random 7-char name>`.
    ///
    /// The root comes from [`std::env::temp_dir`] rather than a hard-coded `/tmp`, so it
    /// honours `TMPDIR` and works on platforms without `/tmp`. Path segments are joined
    /// with the platform separator instead of being concatenated into a string.
    pub fn new_tmp() -> Self {
        let root = std::env::temp_dir()
            .join("musicgpt-tests")
            .join(rand_string());
        // Keep handing `AppFs::new` a `String`, the same argument type as before.
        Self::new(root.to_string_lossy().into_owned())
    }
}