use std::collections::VecDeque;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio_util::sync::CancellationToken;
/// A request to generate audio from a text prompt.
#[derive(Clone, Debug)]
pub struct AudioGenerationRequest {
    /// Caller-chosen identifier. It is echoed back in every [`BackendOutboundMsg`] about this
    /// job and is used to target the job with [`BackendInboundMsg::Abort`].
    pub id: String,
    /// Text prompt forwarded to [`JobProcessor::process`].
    pub prompt: String,
    /// Forwarded as `secs` to [`JobProcessor::process`].
    /// NOTE(review): presumably the requested audio length in seconds; confirm against the
    /// processor implementation.
    pub secs: usize,
}
/// Messages sent *to* the backend through the sender returned by
/// [`AudioGenerationBackend::run`].
#[derive(Clone, Debug)]
pub enum BackendInboundMsg {
    /// Appends a new job to the back of the queue.
    Request(AudioGenerationRequest),
    /// Aborts the queued (or running) job whose request id equals this string.
    Abort(String),
}
/// Messages emitted *by* the backend on the receiver returned by
/// [`AudioGenerationBackend::run`].
///
/// A job that gets started produces one `Start`, then whatever `Progress` messages its
/// processor reports, then exactly one `Response` or `Failure`.
#[derive(Clone, Debug)]
pub enum BackendOutboundMsg {
    /// The given request is about to be processed.
    Start(AudioGenerationRequest),
    /// `(request id, output)`: the job finished and this is what the processor returned.
    /// NOTE(review): the `VecDeque<f32>` is presumably audio samples; confirm.
    Response((String, VecDeque<f32>)),
    /// `(request id, error text)`: the processor returned an error (rendered with `to_string`).
    /// Aborted jobs also end up here, carrying whatever error the processor returns.
    Failure((String, String)),
    /// `(request id, elapsed / total)` as reported through the processor's progress callback.
    Progress((String, f32)),
}
/// A queued request together with the token used to cancel that job individually.
#[derive(Clone, Debug)]
struct Job {
    /// The request this job was created from.
    req: AudioGenerationRequest,
    /// Cancelled when an abort for this job's id arrives. It is checked from the progress
    /// callback while the job runs, so the backend relies on cloned tokens sharing their
    /// cancellation state (tokio-util `CancellationToken` semantics).
    abort_token: CancellationToken,
}
impl Job {
fn new(req: AudioGenerationRequest) -> Self {
Self {
req,
abort_token: CancellationToken::new(),
}
}
}
/// Generates audio for a single job.
///
/// Implementations must be `Send + Sync`: the backend stores them behind an `Arc` and calls
/// them from its worker thread.
pub trait JobProcessor: Send + Sync {
    /// Runs one job for `prompt` with the requested `secs`.
    ///
    /// `on_progress(elapsed, total)` is the progress hook. The backend forwards
    /// `elapsed / total` as a [`BackendOutboundMsg::Progress`] and returns `true` once the job
    /// (or the whole backend) has been aborted. The backend expects the implementation to
    /// stop and return an error at that point; the tests expect the error text "Aborted".
    ///
    /// # Errors
    /// Any error is reported to the client as a [`BackendOutboundMsg::Failure`] carrying the
    /// error's `to_string()` text.
    fn process(
        &self,
        prompt: &str,
        secs: usize,
        on_progress: Box<dyn Fn(f32, f32) -> bool + Sync + Send + 'static>,
    ) -> ort::Result<VecDeque<f32>>;
}
/// Audio generation service that processes queued jobs one at a time, in FIFO order.
///
/// Clones share the same processor and job queue (via `Arc`) and a clone of the same
/// backend-wide abort token.
#[derive(Clone)]
pub struct AudioGenerationBackend {
    /// The processor every job is run through.
    processor: Arc<dyn JobProcessor>,
    /// Pending jobs. The front entry is the job currently being processed, if any.
    job_queue: Arc<RwLock<VecDeque<Job>>>,
    /// Backend-wide cancellation, triggered once every inbound sender has been dropped.
    abort_token: CancellationToken,
}
impl AudioGenerationBackend {
pub fn new<T: JobProcessor + 'static>(processor: T) -> Self {
Self {
processor: Arc::new(processor),
job_queue: Arc::new(RwLock::new(VecDeque::new())),
abort_token: CancellationToken::new(),
}
}
fn job_processing_loop(self, outbound_tx: Sender<BackendOutboundMsg>) {
loop {
let front = {
let jq = self.job_queue.read().unwrap();
jq.front().cloned()
};
let Some(job) = front else {
if self.abort_token.is_cancelled() {
return;
}
std::thread::sleep(Duration::from_millis(10));
continue;
};
let _ = outbound_tx.send(BackendOutboundMsg::Start(job.req.clone()));
let output_tx_clone = outbound_tx.clone();
let abort_token = self.abort_token.clone();
let job_id = job.req.id.clone();
let cbk = Box::new(move |elapsed, total| {
let msg = BackendOutboundMsg::Progress((job_id.clone(), elapsed / total));
let _ = output_tx_clone.send(msg);
abort_token.is_cancelled() || job.abort_token.is_cancelled()
});
let msg = match self.processor.process(&job.req.prompt, job.req.secs, cbk) {
Ok(filepath) => BackendOutboundMsg::Response((job.req.id, filepath)),
Err(err) => BackendOutboundMsg::Failure((job.req.id, err.to_string())),
};
let _ = outbound_tx.send(msg);
self.job_queue.write().unwrap().pop_front();
}
}
fn msg_processing_loop(self, inbound_rx: Receiver<BackendInboundMsg>) {
while let Ok(msg) = inbound_rx.recv() {
match msg {
BackendInboundMsg::Request(req) => {
self.job_queue.write().unwrap().push_back(Job::new(req));
}
BackendInboundMsg::Abort(id) => {
let mut queue = self.job_queue.write().unwrap();
let mut to_remove = None;
for (i, job) in queue.iter().enumerate() {
if job.req.id == id {
to_remove = Some(i);
job.abort_token.cancel();
break;
}
}
if let Some(to_remove) = to_remove {
queue.remove(to_remove);
}
}
}
}
self.abort_token.cancel()
}
pub fn run(self) -> (Sender<BackendInboundMsg>, Receiver<BackendOutboundMsg>) {
let (inbound_tx, inbound_rx) = channel::<BackendInboundMsg>();
let (outbound_tx, outbound_rx) = channel::<BackendOutboundMsg>();
let self_clone = self.clone();
std::thread::spawn(move || self_clone.job_processing_loop(outbound_tx));
std::thread::spawn(move || self.msg_processing_loop(inbound_rx));
(inbound_tx, outbound_rx)
}
}
#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use crate::backend::_test_utils::DummyJobProcessor;

    use super::*;

    /// Builds a `Request` message for a job with the given id, prompt and length.
    fn request(id: &str, prompt: &str, secs: usize) -> BackendInboundMsg {
        BackendInboundMsg::Request(AudioGenerationRequest {
            id: id.to_string(),
            prompt: prompt.to_string(),
            secs,
        })
    }

    #[test]
    fn processes_job() -> anyhow::Result<()> {
        let backend = AudioGenerationBackend::new(DummyJobProcessor::default());
        let (tx, rx) = backend.run();
        let id = Uuid::new_v4().to_string();
        tx.send(request(&id, "", 4))?;
        assert_eq!(rx.recv()?.unwrap_start().id, id);
        for expected in [0.25, 0.5, 0.75, 1.0] {
            assert_eq!(rx.recv()?.unwrap_progress().1, expected);
        }
        assert_eq!(
            rx.recv()?.unwrap_response().1,
            VecDeque::from([0.0, 1.0, 2.0, 3.0])
        );
        Ok(())
    }

    #[test]
    fn handles_job_failure() -> anyhow::Result<()> {
        let backend = AudioGenerationBackend::new(DummyJobProcessor::default());
        let (tx, rx) = backend.run();
        let id = Uuid::new_v4().to_string();
        tx.send(request(&id, "fail at 2", 4))?;
        assert_eq!(rx.recv()?.unwrap_start().id, id);
        assert_eq!(rx.recv()?.unwrap_progress().1, 0.25);
        assert_eq!(rx.recv()?.unwrap_progress().1, 0.5);
        assert_eq!(rx.recv()?.unwrap_err().1, "Failed at 2");
        Ok(())
    }

    // A plain (non-async) test on purpose: `rx.recv()` and the sleep below block the calling
    // thread, which must not happen on an async executor thread.
    #[test]
    #[cfg(not(target_os = "macos"))]
    fn handles_job_cancellation() -> anyhow::Result<()> {
        let backend =
            AudioGenerationBackend::new(DummyJobProcessor::new(Duration::from_millis(200)));
        let (tx, rx) = backend.run();
        let id = Uuid::new_v4().to_string();
        tx.send(request(&id, "", 4))?;
        std::thread::sleep(Duration::from_millis(50));
        tx.send(BackendInboundMsg::Abort(id.clone()))?;
        assert_eq!(rx.recv()?.unwrap_start().id, id);
        assert_eq!(rx.recv()?.unwrap_progress().1, 0.25);
        assert_eq!(rx.recv()?.unwrap_err().1, "Aborted");

        // The backend must keep serving new jobs after an abort.
        let id = Uuid::new_v4().to_string();
        tx.send(request(&id, "", 1))?;
        assert_eq!(rx.recv()?.unwrap_start().id, id);
        assert_eq!(rx.recv()?.unwrap_progress().1, 1.0);
        Ok(())
    }
}