use std::collections::VecDeque;
use derive_more::IsVariant;
use super::{
encoder::{StreamingEncoder, StreamingEncoderBackend},
mel_spectrogram::IncrementalMelSpectrogram,
};
use crate::{
Array,
error::{Error, LayerKeyedPayload, Result},
};
/// One window waiting in [`SessionRetryState`]'s finalize queue.
///
/// Entries are created by `SessionRetryState::enqueue_finalize`, which sets
/// `fallback_consumed` to `false`.
#[derive(Debug)]
pub(super) struct PendingFinalize {
    /// The window passed to `enqueue_finalize`.
    pub(super) encoder_output: Array,
    /// Starts `false` on enqueue. NOTE(review): the code that flips it is not
    /// visible in this file — presumably set once a fallback path has used this
    /// entry; confirm against the finalize consumer.
    pub(super) fallback_consumed: bool,
}
/// The single resume point a session can owe after a failed step.
///
/// `IsVariant` derives `is_*` predicates for each variant.
#[derive(Debug, IsVariant)]
pub(super) enum RetryStage {
    /// A `mel_processor.flush()` still has to be (re)run; see
    /// `SessionRetryState::discharge_stop_mel_flush`.
    StopMelFlush,
    /// Mel frames that still have to be fed to the streaming encoder; see
    /// `SessionRetryState::discharge_stop_encoder_feed`.
    StopEncoderFeed(Array),
    /// Staged after a successful encoder feed that produced a non-zero count,
    /// or explicitly via `arm_decode_owed`. NOTE(review): presumably means a
    /// decode pass is still owed — confirm with the decode caller.
    DecodeOwed,
    /// Payload stored by `arm_stop_partial_decode` and handed back unchanged
    /// by `take_stop_partial_decode_features`.
    StopPartialDecode(Option<Array>),
}
/// Outstanding retry work for a session.
///
/// Holds at most one staged [`RetryStage`] (each `stage_*`/`arm_*` method
/// overwrites it) plus a FIFO of windows waiting to be finalized.
#[derive(Debug)]
pub(super) struct SessionRetryState {
    /// The stage to resume from on the next retry, if any.
    resume_at: Option<RetryStage>,
    /// Windows queued by `enqueue_finalize`, in enqueue order.
    finalize_queue: VecDeque<PendingFinalize>,
}
impl Default for SessionRetryState {
fn default() -> Self {
Self::new()
}
}
impl SessionRetryState {
    /// Creates a state with no staged resume point and an empty finalize queue.
    pub(super) fn new() -> Self {
        Self {
            resume_at: None,
            finalize_queue: VecDeque::new(),
        }
    }

    /// Returns `true` while any retry work is outstanding: a staged resume
    /// point, or at least one queued finalize window.
    #[inline(always)]
    pub(super) fn has_obligation(&self) -> bool {
        self.resume_at.is_some() || !self.finalize_queue.is_empty()
    }

    /// Borrows the currently staged resume point, if any.
    #[inline(always)]
    pub(super) fn resume_at(&self) -> Option<&RetryStage> {
        self.resume_at.as_ref()
    }

    /// Returns `true` if the staged resume point is [`RetryStage::StopMelFlush`].
    #[inline(always)]
    pub(super) fn has_pending_stop_mel_flush(&self) -> bool {
        matches!(self.resume_at, Some(RetryStage::StopMelFlush))
    }

    /// Returns `true` if the staged resume point is [`RetryStage::StopEncoderFeed`].
    #[inline(always)]
    pub(super) fn has_pending_stop_encoder_feed(&self) -> bool {
        matches!(self.resume_at, Some(RetryStage::StopEncoderFeed(_)))
    }

    /// Returns `true` if the staged resume point is [`RetryStage::DecodeOwed`].
    #[inline(always)]
    pub(super) fn has_decode_owed(&self) -> bool {
        matches!(self.resume_at, Some(RetryStage::DecodeOwed))
    }

    /// Read-only view of the queued finalize windows, oldest first.
    pub(super) fn finalize_queue(&self) -> &VecDeque<PendingFinalize> {
        &self.finalize_queue
    }

    /// Mutable access to the queued finalize windows, oldest first.
    pub(super) fn finalize_queue_mut(&mut self) -> &mut VecDeque<PendingFinalize> {
        &mut self.finalize_queue
    }

    /// Appends `window` to the back of the finalize queue with
    /// `fallback_consumed` set to `false`.
    pub(super) fn enqueue_finalize(&mut self, window: Array) {
        self.finalize_queue.push_back(PendingFinalize {
            encoder_output: window,
            fallback_consumed: false,
        });
    }

    /// Retries a staged [`RetryStage::StopEncoderFeed`] by feeding its mel
    /// frames to `encoder`.
    ///
    /// * No stage, or a different stage, is staged: returns `Ok(0)` and leaves
    ///   the staged resume point untouched.
    /// * The feed returns `n > 0`: stages [`RetryStage::DecodeOwed`] and
    ///   returns `Ok(n)`.
    /// * The feed returns `0`: clears the resume point and returns `Ok(0)`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `encoder.feed`. The mel frames are re-staged
    /// as `StopEncoderFeed`, so the call can be retried.
    pub(super) fn discharge_stop_encoder_feed<B>(
        &mut self,
        encoder: &mut StreamingEncoder<B>,
    ) -> Result<usize>
    where
        B: StreamingEncoderBackend,
    {
        let mel_frames = match self.resume_at.take() {
            Some(RetryStage::StopEncoderFeed(mel_frames)) => mel_frames,
            other => {
                // Not our stage: put back whatever was staged (possibly None) so
                // a pending StopMelFlush / DecodeOwed / StopPartialDecode is not lost.
                self.resume_at = other;
                return Ok(0);
            }
        };
        let count = match encoder.feed(&mel_frames) {
            Ok(n) => n,
            Err(e) => {
                // Feed failed: keep owing the same frames.
                self.resume_at = Some(RetryStage::StopEncoderFeed(mel_frames));
                return Err(e);
            }
        };
        if count > 0 {
            self.resume_at = Some(RetryStage::DecodeOwed);
        }
        Ok(count)
    }

    /// Stages `mel_frames` as a [`RetryStage::StopEncoderFeed`] obligation,
    /// overwriting any previously staged resume point.
    pub(super) fn stage_stop_encoder_feed(&mut self, mel_frames: Array) {
        self.resume_at = Some(RetryStage::StopEncoderFeed(mel_frames));
    }

    /// Stages a [`RetryStage::StopMelFlush`] obligation, overwriting any
    /// previously staged resume point.
    pub(super) fn stage_stop_mel_flush(&mut self) {
        self.resume_at = Some(RetryStage::StopMelFlush);
    }

    /// Clears the resume point only if it is [`RetryStage::StopMelFlush`].
    pub(super) fn clear_stop_mel_flush(&mut self) {
        if matches!(self.resume_at, Some(RetryStage::StopMelFlush)) {
            self.resume_at = None;
        }
    }

    /// Retries a staged [`RetryStage::StopMelFlush`] by flushing `mel_processor`.
    ///
    /// Returns `Ok(None)` if no `StopMelFlush` is staged (the staged point is
    /// left untouched) or if the flush produced nothing (the point is cleared).
    /// When the flush yields mel frames, a clone is staged as
    /// [`RetryStage::StopEncoderFeed`] and the original is returned.
    ///
    /// # Errors
    ///
    /// * The flush fails: `StopMelFlush` is re-staged and the error is returned.
    /// * Cloning the flushed mel fails: the original mel is staged as
    ///   `StopEncoderFeed` and an [`Error::LayerKeyed`] is returned.
    pub(super) fn discharge_stop_mel_flush(
        &mut self,
        mel_processor: &mut IncrementalMelSpectrogram,
    ) -> Result<Option<Array>> {
        self.discharge_stop_mel_flush_with_clone(mel_processor, Array::try_clone)
    }

    /// Implementation of [`Self::discharge_stop_mel_flush`] with the clone
    /// operation injected.
    ///
    /// NOTE(review): `clone_fn` is presumably a seam so tests can simulate a
    /// clone failure; confirm in the `tests` module.
    fn discharge_stop_mel_flush_with_clone<F>(
        &mut self,
        mel_processor: &mut IncrementalMelSpectrogram,
        clone_fn: F,
    ) -> Result<Option<Array>>
    where
        F: FnOnce(&Array) -> Result<Array>,
    {
        // Matching against the place (no bindings) leaves any other stage in place.
        let Some(RetryStage::StopMelFlush) = self.resume_at else {
            return Ok(None);
        };
        // Clear first; each failure path below re-stages whatever is still owed.
        self.resume_at = None;
        let mel_opt = match mel_processor.flush() {
            Ok(m) => m,
            Err(e) => {
                self.resume_at = Some(RetryStage::StopMelFlush);
                return Err(e);
            }
        };
        // Nothing flushed means nothing further is owed.
        let Some(mel) = mel_opt else {
            return Ok(None);
        };
        match clone_fn(&mel) {
            Ok(for_obligation) => {
                // The clone backs the obligation; the caller gets the original.
                self.resume_at = Some(RetryStage::StopEncoderFeed(for_obligation));
                Ok(Some(mel))
            }
            Err(e) => {
                // Only one copy exists: keep it in the obligation, not the caller.
                self.resume_at = Some(RetryStage::StopEncoderFeed(mel));
                Err(Error::LayerKeyed(LayerKeyedPayload::new(
                    "StopMelFlush: failed to clone flushed mel for in-call use \
                     (obligation preserved as StopEncoderFeed with original payload, \
                     retry stop() to discharge)",
                    e,
                )))
            }
        }
    }

    /// Stages [`RetryStage::DecodeOwed`], overwriting any previously staged
    /// resume point.
    pub(super) fn arm_decode_owed(&mut self) {
        self.resume_at = Some(RetryStage::DecodeOwed);
    }

    /// Clears the resume point only if it is [`RetryStage::DecodeOwed`].
    pub(super) fn clear_decode_owed(&mut self) {
        if matches!(self.resume_at, Some(RetryStage::DecodeOwed)) {
            self.resume_at = None;
        }
    }

    /// Stages [`RetryStage::StopPartialDecode`] carrying `audio_features`,
    /// overwriting any previously staged resume point.
    pub(super) fn arm_stop_partial_decode(&mut self, audio_features: Option<Array>) {
        self.resume_at = Some(RetryStage::StopPartialDecode(audio_features));
    }

    /// Returns `true` if the staged resume point is [`RetryStage::StopPartialDecode`].
    #[inline(always)]
    pub(super) fn has_pending_stop_partial_decode(&self) -> bool {
        matches!(self.resume_at, Some(RetryStage::StopPartialDecode(_)))
    }

    /// Takes the payload of a staged [`RetryStage::StopPartialDecode`].
    ///
    /// Returns `Some(features)` and clears the resume point when that stage is
    /// staged. `features` is exactly the `Option<Array>` passed to
    /// `arm_stop_partial_decode`. Otherwise returns `None` and leaves any other
    /// staged resume point untouched.
    pub(super) fn take_stop_partial_decode_features(&mut self) -> Option<Option<Array>> {
        match self.resume_at.take() {
            Some(RetryStage::StopPartialDecode(audio_features)) => Some(audio_features),
            other => {
                // Not our stage: restore it unchanged.
                self.resume_at = other;
                None
            }
        }
    }

    /// Drops every obligation: clears the resume point and empties the
    /// finalize queue.
    pub(super) fn clear_all(&mut self) {
        self.resume_at = None;
        self.finalize_queue.clear();
    }
}
#[cfg(test)]
mod tests;