musicgpt 0.3.1

Generate music samples from a natural language prompt locally on your own computer
use axum::extract::WebSocketUpgrade;
use axum::response::Html;
use axum::routing::get;
use axum::Router;
use tower_http::services::ServeDir;

use crate::storage::AppFs;
use crate::backend::audio_generation_fanout::audio_generation_fanout;
use crate::backend::audio_generation_backend::{AudioGenerationBackend, JobProcessor};
use crate::backend::ws_handler::WsHandler;
use crate::backend::music_gpt_ws_handler::{Info, MusicGptWsHandler};

/// Handler that returns the web UI page.
///
/// The page is `web/dist/index.html` under the crate manifest directory; it is
/// embedded into the binary at compile time through `include_str!`, so nothing
/// is read from disk when the handler runs.
async fn web_app() -> Html<&'static str> {
    const INDEX_HTML: &str = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/web/dist/index.html"
    ));
    Html(INDEX_HTML)
}

/// Starts the HTTP/WebSocket server and keeps serving until it stops or fails.
///
/// Routes registered on the server:
/// - `/` returns the embedded web app page.
/// - `/files` serves files from the storage root directory.
/// - `/ws` upgrades the connection to a WebSocket handled by a clone of the
///   [`MusicGptWsHandler`] built here.
///
/// # Arguments
/// * `storage` - application file system; its `root` is what `/files` exposes.
/// * `processor` - job processor handed over to the [`AudioGenerationBackend`].
/// * `port` - TCP port to listen on; the listener is bound on `0.0.0.0`.
/// * `open` - when `true`, attempts to open `http://localhost:{port}` with the
///   system's default handler once the listener is bound.
///
/// # Errors
/// Returns an error if the listener cannot be bound or if serving fails.
pub async fn run<T: JobProcessor + 'static>(
    storage: AppFs,
    processor: T,
    port: usize,
    open: bool,
) -> anyhow::Result<()> {
    // Read the processor's metadata first: the processor itself is moved into
    // the backend right below.
    let info = Info {
        model: processor.name(),
        device: processor.device(),
    };

    let (ai_tx, ai_rx) = AudioGenerationBackend::new(processor).run();
    let ai_broadcast_tx = audio_generation_fanout(ai_rx, storage.clone());

    // Keep a copy of the root path for the static file service, since `storage`
    // is moved into the WebSocket handler.
    let files_root = storage.root.clone();
    let ws_handler = MusicGptWsHandler {
        ai_tx,
        storage,
        info,
        ai_broadcast_tx,
    };

    let app = Router::new()
        .route("/", get(web_app))
        .nest_service("/files", ServeDir::new(files_root))
        .route(
            "/ws",
            get(|upgrade: WebSocketUpgrade| async move {
                let handler = ws_handler.clone();
                upgrade.on_upgrade(move |socket| handler.handle(socket))
            }),
        );

    let bind_addr = format!("0.0.0.0:{port}");
    let listener = tokio::net::TcpListener::bind(bind_addr).await?;

    if open {
        // Best effort: the result is intentionally discarded so that a failure
        // to launch a browser does not prevent the server from running.
        let _ = open::that(format!("http://localhost:{port}"));
    }

    axum::serve(listener, app).await?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU16, Ordering};
    use std::time::Duration;

    use futures_util::{SinkExt, StreamExt};
    use serde::de::DeserializeOwned;
    use serde::Serialize;
    use tokio_tungstenite::connect_async;
    use uuid::Uuid;

    use crate::backend::_test_utils::DummyJobProcessor;
    use crate::backend::music_gpt_ws_handler::{
        AbortGenerationRequest, GenerateAudioRequest, InboundMsg, OutboundMsg,
    };

    use super::*;

    /// Port counter shared by all tests; every `spawn` call takes the current
    /// value and increments it, so each test server gets its own port.
    static PORT: AtomicU16 = AtomicU16::new(8643);

    /// Launches `run` as a background tokio task on a freshly allocated port,
    /// backed by a temporary `AppFs`, without opening a browser.
    ///
    /// Returns the port the server was asked to listen on. The task is
    /// detached: its join handle is dropped.
    fn spawn<P: JobProcessor + 'static>(processor: P) -> usize {
        let storage = AppFs::new_tmp();
        let port = usize::from(PORT.fetch_add(1, Ordering::SeqCst));
        tokio::spawn(run(storage, processor, port, false));
        port
    }

    /// Sends a generation request and checks the full message sequence:
    /// info, chats, start, four progress updates (0.25 through 1.0) and a
    /// result whose audio file can then be fetched under `/files`.
    #[tokio::test]
    async fn sending_a_job_processes_it() -> anyhow::Result<()> {
        let port = spawn(DummyJobProcessor::default());

        let (mut socket, _) = connect_async(&format!("ws://localhost:{port}/ws")).await?;
        let id = Uuid::new_v4();
        let chat_id = Uuid::new_v4();
        let request = InboundMsg::GenerateAudio(GenerateAudioRequest {
            id,
            chat_id,
            prompt: "Create a cool song".to_string(),
            secs: 4,
        });
        socket.send(request.to_tungstenite_msg()).await?;

        // The first three messages only need to be of the expected variant.
        OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?.unwrap_info();
        OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?.unwrap_chats();
        OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?.unwrap_start();

        // Progress is reported in four steps, each tagged with the request ids.
        for expected in [0.25, 0.5, 0.75, 1.0] {
            let msg = OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?;
            let update = msg.unwrap_progress();
            assert_eq!(update.id, id);
            assert_eq!(update.chat_id, chat_id);
            assert_eq!(update.progress, expected);
        }

        let msg = OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?;
        let outcome = msg.unwrap_result();
        assert_eq!(outcome.id, id);
        assert_eq!(outcome.chat_id, chat_id);
        assert_eq!(outcome.relpath, format!("audios/{id}.wav"));

        // The generated file must be reachable through the static file route.
        let file_url = format!("http://localhost:{port}/files/audios/{id}.wav");
        let response = reqwest::get(file_url).await?;
        assert_eq!(response.status(), 200);

        Ok(())
    }

    /// Starts a job, sends an abort request 150ms later and checks that the
    /// stream contains info, chats, start, two progress updates (0.25, 0.5)
    /// and finally an `Aborted` error for the same job.
    ///
    /// The processor is built with a 100ms duration (presumably the delay per
    /// generation step; confirm against `DummyJobProcessor`).
    #[tokio::test]
    async fn can_abort_a_job() -> anyhow::Result<()> {
        let port = spawn(DummyJobProcessor::new(Duration::from_millis(100)));

        let (mut socket, _) = connect_async(&format!("ws://localhost:{port}/ws")).await?;
        let id = Uuid::new_v4();
        let chat_id = Uuid::new_v4();
        let request = InboundMsg::GenerateAudio(GenerateAudioRequest {
            id,
            chat_id,
            prompt: "Create a cool song".to_string(),
            secs: 4,
        });
        socket.send(request.to_tungstenite_msg()).await?;

        // Let the job make some progress before asking for it to be aborted.
        tokio::time::sleep(Duration::from_millis(150)).await;
        let abort = InboundMsg::AbortGeneration(AbortGenerationRequest { id, chat_id });
        socket.send(abort.to_tungstenite_msg()).await?;

        // The first three messages only need to be of the expected variant.
        OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?.unwrap_info();
        OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?.unwrap_chats();
        OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?.unwrap_start();

        // Only the first two progress steps are expected before the abort.
        for expected in [0.25, 0.5] {
            let msg = OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?;
            let update = msg.unwrap_progress();
            assert_eq!(update.id, id);
            assert_eq!(update.chat_id, chat_id);
            assert_eq!(update.progress, expected);
        }

        let msg = OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?;
        let failure = msg.unwrap_err();
        assert_eq!(failure.id, id);
        assert_eq!(failure.chat_id, chat_id);
        assert_eq!(failure.error, "Aborted");

        Ok(())
    }

    /// Submits a job with the prompt `fail at 2` and checks that the stream
    /// contains info, chats, start, two progress updates (0.25, 0.5) and then
    /// an error message reading `Failed at 2` for the same job.
    #[tokio::test]
    async fn handles_job_failures() -> anyhow::Result<()> {
        let port = spawn(DummyJobProcessor::default());

        let (mut socket, _) = connect_async(&format!("ws://localhost:{port}/ws")).await?;
        let id = Uuid::new_v4();
        let chat_id = Uuid::new_v4();
        let request = InboundMsg::GenerateAudio(GenerateAudioRequest {
            id,
            chat_id,
            prompt: "fail at 2".to_string(),
            secs: 4,
        });
        socket.send(request.to_tungstenite_msg()).await?;

        // The first three messages only need to be of the expected variant.
        OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?.unwrap_info();
        OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?.unwrap_chats();
        OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?.unwrap_start();

        // Two progress steps are reported before the failure shows up.
        for expected in [0.25, 0.5] {
            let msg = OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?;
            let update = msg.unwrap_progress();
            assert_eq!(update.id, id);
            assert_eq!(update.chat_id, chat_id);
            assert_eq!(update.progress, expected);
        }

        let msg = OutboundMsg::from_tungstenite_msg(socket.next().await.unwrap()?)?;
        let failure = msg.unwrap_err();
        assert_eq!(failure.id, id);
        assert_eq!(failure.chat_id, chat_id);
        assert_eq!(failure.error, "Failed at 2");

        Ok(())
    }

    /// Test helper for converting values to and from tungstenite WebSocket
    /// messages, using JSON as the wire format.
    trait TungsteniteMsg: Sized {
        /// Encodes `self` as a tungstenite message.
        fn to_tungstenite_msg(&self) -> tokio_tungstenite::tungstenite::Message;

        /// Decodes a value of this type from a tungstenite message.
        fn from_tungstenite_msg(
            msg: tokio_tungstenite::tungstenite::Message,
        ) -> anyhow::Result<Self>;
    }

    /// Blanket implementation for every type that serde can both serialize and
    /// deserialize.
    impl<T: Serialize + DeserializeOwned> TungsteniteMsg for T {
        /// Serializes `self` to JSON and wraps it in a text message.
        ///
        /// Panics if serialization fails.
        fn to_tungstenite_msg(&self) -> tokio_tungstenite::tungstenite::Message {
            let json = serde_json::to_string(self).expect("Could not serialize msg");
            tokio_tungstenite::tungstenite::Message::Text(json)
        }

        /// Reads the message as text and parses it as JSON.
        ///
        /// Returns an error if the message cannot be viewed as text or if the
        /// JSON does not deserialize into `Self`.
        fn from_tungstenite_msg(
            msg: tokio_tungstenite::tungstenite::Message,
        ) -> anyhow::Result<Self> {
            let text = msg.to_text()?;
            let value = serde_json::de::from_str(text)?;
            Ok(value)
        }
    }
}