use crate::model::SdkError;
use std::path::Path;
use std::sync::{Arc, RwLock};
use xybrid_core::streaming::{
PartialResult as CorePartialResult, StreamConfig as CoreStreamConfig, StreamSession,
StreamState as CoreStreamState, StreamStats as CoreStreamStats,
};
/// SDK-facing lifecycle state of a stream.
///
/// Each variant is produced from the identically named variant of the core
/// `StreamState` through the `From` conversion defined in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamState {
/// Converted from the core `Idle` state.
Idle,
/// Converted from the core `Streaming` state.
Streaming,
/// Converted from the core `Finalizing` state.
Finalizing,
/// Converted from the core `Completed` state.
Completed,
/// Converted from the core `Error` state. `XybridStream::state` and
/// `XybridStream::stats` also report this when the stream lock cannot be read.
Error,
}
impl From<CoreStreamState> for StreamState {
fn from(state: CoreStreamState) -> Self {
match state {
CoreStreamState::Idle => StreamState::Idle,
CoreStreamState::Streaming => StreamState::Streaming,
CoreStreamState::Finalizing => StreamState::Finalizing,
CoreStreamState::Completed => StreamState::Completed,
CoreStreamState::Error => StreamState::Error,
}
}
}
/// SDK view of an intermediate result, built from the core `PartialResult`
/// by the `From` conversion in this file.
#[derive(Debug, Clone)]
pub struct PartialResult {
/// Text copied from the core partial result.
pub text: String,
/// `is_stable` flag copied from the core partial result.
pub is_stable: bool,
/// Value of the core result's `chunk_sequence` field.
pub chunk_index: u64,
/// The core result's `audio_duration`, expressed in whole milliseconds.
pub audio_duration_ms: u64,
}
/// Converts a core partial result into the SDK representation.
///
/// `text` and `is_stable` are moved across unchanged, `chunk_sequence` becomes
/// `chunk_index`, and `audio_duration` is reported in whole milliseconds.
impl From<CorePartialResult> for PartialResult {
    fn from(result: CorePartialResult) -> Self {
        // `as_millis()` yields a `u128`; clamp before narrowing so an
        // out-of-range duration saturates at `u64::MAX` instead of wrapping.
        let audio_duration_ms =
            result.audio_duration.as_millis().min(u128::from(u64::MAX)) as u64;

        Self {
            text: result.text,
            is_stable: result.is_stable,
            chunk_index: result.chunk_sequence,
            audio_duration_ms,
        }
    }
}
/// Value returned by `XybridStream::flush`.
#[derive(Debug, Clone)]
pub struct TranscriptionResult {
/// Text returned by the underlying session's `flush` call.
pub text: String,
/// The session's `audio_duration` statistic, read right after the flush,
/// in whole milliseconds.
pub duration_ms: u64,
/// The session's `chunks_processed` statistic, read right after the flush.
pub chunks_processed: u64,
}
/// SDK snapshot of stream statistics, built from the core `StreamStats`
/// by the `From` conversion in this file.
#[derive(Debug, Clone)]
pub struct StreamStats {
/// Stream state, converted from the core state.
pub state: StreamState,
/// `samples_received` counter copied from the core stats.
pub samples_received: u64,
/// `samples_processed` counter copied from the core stats.
pub samples_processed: u64,
/// `chunks_processed` counter copied from the core stats.
pub chunks_processed: u64,
/// `transcript_length` copied from the core stats
/// (unit not visible here — presumably characters or bytes; confirm in core).
pub transcript_length: usize,
/// The core `audio_duration`, expressed in whole milliseconds.
pub audio_duration_ms: u64,
}
/// Converts core stream statistics into the SDK representation.
///
/// Counters are copied unchanged, the state goes through the `StreamState`
/// conversion, and `audio_duration` is reported in whole milliseconds.
impl From<CoreStreamStats> for StreamStats {
    fn from(stats: CoreStreamStats) -> Self {
        // `as_millis()` yields a `u128`; clamp before narrowing so an
        // out-of-range duration saturates at `u64::MAX` instead of wrapping.
        let audio_duration_ms =
            stats.audio_duration.as_millis().min(u128::from(u64::MAX)) as u64;

        Self {
            state: stats.state.into(),
            samples_received: stats.samples_received,
            samples_processed: stats.samples_processed,
            chunks_processed: stats.chunks_processed,
            transcript_length: stats.transcript_length,
            audio_duration_ms,
        }
    }
}
/// Data guarded by the lock inside `XybridStream`.
struct StreamHandle {
/// Core streaming session that all stream operations are forwarded to.
session: StreamSession,
/// Model identifier string supplied when the stream was constructed.
model_id: String,
}
/// Handle to a streaming session.
///
/// The session lives behind an `Arc<RwLock<..>>`, and the `Clone` impl in this
/// file copies only the `Arc`, so clones refer to the same underlying session.
pub struct XybridStream {
/// Shared, lock-protected session state.
handle: Arc<RwLock<StreamHandle>>,
}
impl XybridStream {
pub(crate) fn new<P: AsRef<Path>>(
model_dir: P,
config: CoreStreamConfig,
model_id: &str,
) -> Result<Self, SdkError> {
let session = StreamSession::new(model_dir, config)
.map_err(|e| SdkError::LoadError(format!("Failed to create stream session: {}", e)))?;
Ok(Self {
handle: Arc::new(RwLock::new(StreamHandle {
session,
model_id: model_id.to_string(),
})),
})
}
pub fn state(&self) -> StreamState {
self.handle
.read()
.map(|h| h.session.state().into())
.unwrap_or(StreamState::Error)
}
pub fn stats(&self) -> StreamStats {
self.handle
.read()
.map(|h| h.session.stats().into())
.unwrap_or_else(|_| StreamStats {
state: StreamState::Error,
samples_received: 0,
samples_processed: 0,
chunks_processed: 0,
transcript_length: 0,
audio_duration_ms: 0,
})
}
pub fn has_vad(&self) -> bool {
self.handle
.read()
.map(|h| h.session.has_vad())
.unwrap_or(false)
}
pub fn model_id(&self) -> String {
self.handle
.read()
.map(|h| h.model_id.clone())
.unwrap_or_default()
}
pub fn feed(&self, samples: &[f32]) -> Result<Option<PartialResult>, SdkError> {
let mut handle = self
.handle
.write()
.map_err(|_| SdkError::InferenceError("Failed to acquire stream lock".to_string()))?;
handle
.session
.feed(samples)
.map_err(|e| SdkError::InferenceError(format!("Feed failed: {}", e)))?;
Ok(handle.session.partial_result().map(|p| p.into()))
}
pub fn partial_result(&self) -> Option<PartialResult> {
self.handle
.read()
.ok()
.and_then(|h| h.session.partial_result().map(|p| p.into()))
}
pub fn flush(&self) -> Result<TranscriptionResult, SdkError> {
let mut handle = self
.handle
.write()
.map_err(|_| SdkError::InferenceError("Failed to acquire stream lock".to_string()))?;
let text = handle
.session
.flush()
.map_err(|e| SdkError::InferenceError(format!("Flush failed: {}", e)))?;
let stats = handle.session.stats();
Ok(TranscriptionResult {
text,
duration_ms: stats.audio_duration.as_millis() as u64,
chunks_processed: stats.chunks_processed,
})
}
pub fn reset(&self) -> Result<(), SdkError> {
let mut handle = self
.handle
.write()
.map_err(|_| SdkError::InferenceError("Failed to acquire stream lock".to_string()))?;
handle.session.reset();
Ok(())
}
}
/// Cloning produces another handle to the same underlying session: only the
/// `Arc` is duplicated, never the session itself.
impl Clone for XybridStream {
    fn clone(&self) -> Self {
        // Spelled as `Arc::clone` to make the reference-count bump explicit.
        let handle = Arc::clone(&self.handle);
        Self { handle }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Every core state must map onto the SDK variant of the same name.
    #[test]
    fn test_stream_state_conversion() {
        assert_eq!(StreamState::from(CoreStreamState::Idle), StreamState::Idle);
        assert_eq!(
            StreamState::from(CoreStreamState::Streaming),
            StreamState::Streaming
        );
        assert_eq!(
            StreamState::from(CoreStreamState::Finalizing),
            StreamState::Finalizing
        );
        assert_eq!(
            StreamState::from(CoreStreamState::Completed),
            StreamState::Completed
        );
        assert_eq!(
            StreamState::from(CoreStreamState::Error),
            StreamState::Error
        );
    }

    /// Fields of a core partial result carry over, with the duration reported
    /// in milliseconds and `chunk_sequence` exposed as `chunk_index`.
    #[test]
    fn test_partial_result_conversion() {
        let core = CorePartialResult {
            text: "hello".to_string(),
            confidence: Some(0.9),
            is_stable: true,
            audio_duration: Duration::from_millis(1500),
            chunk_sequence: 5,
        };
        let result: PartialResult = core.into();
        assert_eq!(result.text, "hello");
        assert!(result.is_stable);
        assert_eq!(result.chunk_index, 5);
        assert_eq!(result.audio_duration_ms, 1500);
    }
}